Naive Bayes を Java で実装してみた
今回も元ネタはこの本を使います。やっぱり Weka で実行すると
するりと実行できてしまって頭に残らないので、Java で実装する
ことを考えます。

- 作者: 荒木雅弘
- 出版社/メーカー: 森北出版
- 発売日: 2014/03/29
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (4件) を見る
package com.minosys.aplication; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; /** * 単純ベイジアン検出器 * * L(D) = arg[i] max P(\Omega_i) * ΠP(x_j|Omega_i) * ただし arg[i] max は引数が最大となるクラス値 i を取り出すことを示す。 * P(A) は集合 A の事前確率、P(x_j|A) はサンプルの入力ベクトルを * x = {x_1, x_2, ..., x_n } と表現した場合に j 番目の属性が x_j となる * 事後確率を表すものとする。 * * @author minosys * */ public class NaiveBayes implements Classifier { public static class NumValueWithProbability { int count; double probability; public String toString() { return count + "(" + probability + ")"; } } // 頻度および出現確率は [判定値, 属性名, 属性値] のマップとなる Map<String, Map<String, Map<String, NumValueWithProbability>>> concept; Map<String, Double> zappMap; /** * 判定属性が特定の判定値を持つ状況下で、指定した属性が出現する回数+Nを返す。 * ゼロ頻度問題を回避するため、各 count 値は +1 されていることに注意。 * Σ_{aname} C_{aname} は判定値の場合の和とはならず、判定値の取り得る * 数だけ大きい。 * @param zvalue * @param aname * @return */ private int calcTotalCount(String zvalue, String aname) { int count = 0; Map<String, NumValueWithProbability> map = concept.get(zvalue).get(aname); for (Iterator<Map.Entry<String, NumValueWithProbability>> i = map.entrySet().iterator(); i.hasNext(); ) { count += i.next().getValue().count; } return count; } /** * 各属性に対する出現確率を更新する。 * 必ず analyze() の最後に呼び出すこと。 */ private void updateProbability() { int totalCount = 0; zappMap = new HashMap<String, Double>(); // 各判定クラスに対して for (Iterator<Map.Entry<String, Map<String, Map<String, NumValueWithProbability>>>> i = concept.entrySet().iterator(); i.hasNext(); ) { Map.Entry<String, Map<String, Map<String, NumValueWithProbability>>> e1 = i.next(); Map<String, Map<String, NumValueWithProbability>> map = e1.getValue(); int count = 0; // 各属性に対して for (Iterator<Map.Entry<String, Map<String, NumValueWithProbability>>> j = map.entrySet().iterator(); j.hasNext(); ) { Map.Entry<String, Map<String, NumValueWithProbability>> e2 = j.next(); int total = calcTotalCount(e1.getKey(), e2.getKey()); // 属性の取り得る値の集合の要素数 + 属性数 assert(total > 0); // 各属性値に対して for (Iterator<Map.Entry<String, NumValueWithProbability>> k = e2.getValue().entrySet().iterator(); k.hasNext(); ) { Map.Entry<String, NumValueWithProbability> e3 = k.next(); count += e3.getValue().count - 1; // 出現数 e3.getValue().probability = (double)e3.getValue().count / (double) total; } } totalCount += count; zappMap.put(e1.getKey(), (double)count); } if (totalCount > 0) { // 判定値の出現割合を計算する for (Iterator<String> x = concept.keySet().iterator(); x.hasNext(); ) { String s = x.next(); zappMap.put(s, zappMap.get(s) / (double)totalCount); } } } @Override public void analyze(GenericData data) { // TODO 自動生成されたメソッド・スタブ concept = new HashMap<String, Map<String, Map<String, NumValueWithProbability>>>(); // マトリックスを初期化する String zname = data.attributeNameList.get(data.lastIindex()); for (Iterator<String> zp = data.attributes.get(zname).iterator(); zp.hasNext(); ) { String z = zp.next(); Map<String, Map<String, NumValueWithProbability>> amap = new HashMap<String, Map<String, NumValueWithProbability>>(); for (Iterator<Map.Entry<String, HashSet<String>>> akp = data.attributes.entrySet().iterator(); akp.hasNext(); ) { Map.Entry<String, HashSet<String>> e2 = akp.next(); if (e2.getKey().equals(zname)) { // 判定名は含めない continue; } Map<String, NumValueWithProbability> vmap = new HashMap<String, NumValueWithProbability>(); for (Iterator<String> avp = e2.getValue().iterator(); avp.hasNext(); ) { NumValueWithProbability np = new NumValueWithProbability(); np.count = 1; String av = avp.next(); vmap.put(av, np); } amap.put(e2.getKey(), vmap); } concept.put(z, amap); } // 各属性名に対して for (Iterator<Map.Entry<String, Integer>> i = data.attributeNameIndex.entrySet().iterator(); i.hasNext(); ) { Map.Entry<String, Integer> e = i.next(); // 属性名が判定名である場合は計算対象から外す(二重計算するのを回避するため) if (e.getValue() == data.lastIindex()) { continue; } // 各行を分類する for (int j = 0; j < data.data.size(); ++j) { String aval = data.get(j, e.getValue()); // 属性値 String zval = data.get(j, data.lastIindex()); // 判定値 NumValueWithProbability np = concept.get(zval).get(e.getKey()).get(aval); ++np.count; } } // 出現確率を計算する updateProbability(); } @Override public ValueWithProbability decide(List<String> slist, Map<String, Integer> amap) { double maxval = 0.0; String maxname = null; Map<String, Double> pr2 = new HashMap<String, Double>(); double px1 = 0.0; // P(x1|z_y)*...*P(z_y) を計算する for (Iterator<Map.Entry<String, Double>> i = zappMap.entrySet().iterator(); i.hasNext(); ) { Map.Entry<String, Double> e = i.next(); double pr = e.getValue(); for (Iterator<Map.Entry<String, Integer>> j = amap.entrySet().iterator(); j.hasNext(); ) { Map.Entry<String, Integer> e2 = j.next(); if (concept.get(e.getKey()).containsKey(e2.getKey())) { pr = pr * concept.get(e.getKey()).get(e2.getKey()).get(slist.get(e2.getValue())).probability; } } px1 = px1 + pr; pr2.put(e.getKey(), pr); } // P(z_y|x1) を Bayes 定理から計算する for (Iterator<Map.Entry<String, Double>> j = pr2.entrySet().iterator(); j.hasNext(); ) { Map.Entry<String, Double> e = j.next(); double pr = e.getValue() / px1; if (pr > maxval) { maxval = pr; maxname = e.getKey(); } } if (maxname != null) { ValueWithProbability v = new ValueWithProbability(maxname, maxval); return v; } return null; } }
NaiveBayes インスタンスに含まれる concept のキーは <クラス値>,
<特徴量名称>, <特徴値> の組となります。
またクラス値が現れる事前確率を zappMap に記録しています。
ゼロ頻度問題を回避するため、concept の値は 1 からカウントするようにしています。
weather.nominal.arff に対してプログラムを実行した結果は以下の通り。
no(0.35714285714285715) yes(0.6428571428571429) -------------------- outlook rainy | 3 4 overcast | 1 5 sunny | 4 3 temperature mild | 3 5 cool | 2 4 hot | 3 3 humidity normal | 2 7 high | 5 4 windy TRUE | 4 4 FALSE | 3 7 << make a decision for the first 10 samples >> [sunny, hot, high, FALSE, no]:no(0.687969069820332) [sunny, hot, high, TRUE, no]:no(0.8372543592582342) [overcast, hot, high, FALSE, yes]:yes(0.7514719978091196) [rainy, mild, high, FALSE, yes]:yes(0.573353879907018) [rainy, cool, normal, FALSE, yes]:yes(0.8758578235790337) [rainy, cool, normal, TRUE, no]:yes(0.7514719978091196) [overcast, cool, normal, TRUE, yes]:yes(0.9189551239115875) [sunny, mild, high, FALSE, no]:no(0.5695010982114841) [sunny, cool, normal, FALSE, yes]:yes(0.7987358616101131) [rainy, mild, normal, FALSE, yes]:yes(0.854638487208009)
1例のみ正しくない判定がなされるが、これは Weka の結果と同じ。
お手本と算出された確率が微妙に違うんだが、まあいいか。