徒然なる日々を送るソフトウェアデベロッパーの記録(2)

技術上思ったことや感じたことを気ままに記録していくブログです。さくらから移設しました。

Ripper アルゴリズムを実装してみた

機械学習の本を読み始めて2週間。まだ3章が終わったところ。

フリーソフトではじめる機械学習入門

フリーソフトではじめる機械学習入門

知識を獲得するアルゴリズムである Ripper を Weka で実行する
例が書いてあるが、Weka だと今一つ頭に入ってこなかったので、
Java で実装してみる。(備忘録です)

package com.minosys.aplication;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Ripper アルゴリズム:
 * IF (条件式) THEN (判定代数) = (判定値)
 * という規則を並べて知識を表現する。
 * 例:
 * if (A = a1) and (B = b2) THEN Z = z3
 * if (C = c4) THEN Z = z4
 * Z=Z5
 *
 * まず、Ci = {x|Z=Zi}としたとき、#Ci がもっとも小さくなる判定値が述語として選択される。
 * 規則の分類には、以下の式が最も大きくなるものを選択する。
 * RuleEval(N', N) = s * (log N'+ / N' - log N+ / N)
 * ただし、N, N+ は規則 Rule([R]) における事例数、注目事例数、N', N'+ は新たに追加したルール Rule(R', [R]) における事例数、注目事例数とする
 * s は Rule([R])およびRule(R', [R]) で注目事例となる事例数。
 *
 * 結果は concept メンバーに返される。
 * @author minosys
 *
 */
public class Ripper {
	public List<RuleIfThen> concept; // 構築するインスタンス

	/**
	 * partialSet を判定条件によって分類する
	 * @param data
	 * @param partialSet
	 * @return
	 */
	private Map<String, ArrayList<Integer>> classifyResult(GenericData data, ArrayList<Integer> partialSet) {
		Map<String, ArrayList<Integer>> map = new HashMap<String, ArrayList<Integer>>();
		for (Iterator<Integer> i = partialSet.iterator(); i.hasNext(); ) {
			int lineno = i.next();
			String v = data.get(lineno, data.lastIindex());
			ArrayList<Integer> list = map.get(v);
			if (list == null) {
				// 登録されていなかったら新しく作成する
				list = new ArrayList<Integer>();
				map.put(v, list);
			}
			list.add(lineno);
		}
		return map;
	}

	/**
	 * 全ての行を要素に追加し、全体集合を構成する
	 * @param data
	 * @return
	 */
	private ArrayList<Integer> createWholeSet(GenericData data) {
		ArrayList<Integer> list = new ArrayList<Integer>();
		for (int i = 0; i < data.data.size(); ++i) {
			list.add(i);
		}
		return list;
	}

	/**
	 * Map に含まれる要素数を全て足し合わせる
	 * @param map
	 * @return
	 */
	private static int countMap(Map<String, ArrayList<Integer>> map) {
		int count = 0;
		for (Iterator<Map.Entry<String, ArrayList<Integer>>> i = map.entrySet().iterator(); i.hasNext(); ) {
			count += i.next().getValue().size();
		}
		return count;
	}

	/**
	 * src マップから target マップを削除する
	 * @param src
	 * @param target
	 */
	private void subtract(Map<String, ArrayList<Integer>> src, Map<String, ArrayList<Integer>> target) {
		ArrayList<String> emptyKeys = new ArrayList<String>();
		for (Iterator<Map.Entry<String, ArrayList<Integer>>> i = src.entrySet().iterator(); i.hasNext(); ) {
			Map.Entry<String, ArrayList<Integer>> e = i.next();
			ArrayList<Integer> targetSet = target.get(e.getKey());
			if (targetSet != null) {
				e.getValue().removeAll(targetSet);
				if (e.getValue().isEmpty()) {
					emptyKeys.add(e.getKey());
				}
			}
		}
		for (Iterator<String> j = emptyKeys.iterator(); j.hasNext(); ) {
			src.remove(j.next());
		}
	}

	/**
	 * 仮説追加判定をするための判定関数
	 *
	 * @param data
	 * @param mapAfter
	 * @param mapBefore
	 * @return
	 */
	private double calcEvaluator(GenericData data, int acount, int acountAll, int count, int countAll) {
		double score = 0.0;
		// 仮説を設定する前の評価値を計算する; 評価値に -1 を掛けていることに注意
		// before: 評価関数第2項を計算
		if (count > 0) {
			score = - Math.log((double)count/(double)countAll);
		}
		// after: 全体数を計算
		if (acount > 0) {
			score = (double)acount * (Math.log((double)acount / (double)acountAll) + score) / Math.log(2);
		} else {
			score = 0.0;
		}
		return score;
	}

	// 評価関数が最大となる属性名と属性値の組を返す
	private RuleIfThen.KeyValuePair getMaxPair(Map<String, Map<String, Map<String, ArrayList<Integer>>>> map,
			GenericData data,
			ArrayList<Integer> remainSet,
			Map<String, HashSet<String>> attributes,
			Map<String, ArrayList<Integer>> beforeMap,
			String decision) {
		RuleIfThen.KeyValuePair pair = null;
		double eval = -Double.MAX_VALUE;

		// 反対対象がない場合は何もしない
		if (decision == null) {
			return null;
		}

		int decisionCount = Integer.MAX_VALUE;
		// If Then ルールの条件部への追加: 述語は決定されている
		ArrayList<Integer> list = beforeMap.get(decision);
		if (list == null) {
			// リストが含まれていない場合はこれ以上探索しない
			return null;
		}
		decisionCount = beforeMap.get(decision).size();

		for (Iterator<Map.Entry<String, HashSet<String>>> i = attributes.entrySet().iterator(); i.hasNext(); ) {
			Map<String, Map<String, ArrayList<Integer>>> pmap = new HashMap<String, Map<String, ArrayList<Integer>>>();
			Map.Entry<String, HashSet<String>> ea = i.next();
			for (Iterator<String> j = ea.getValue().iterator(); j.hasNext(); ) {
				String v = j.next();
				int acount = 0;
				int acountAll = 0;
				Map<String, ArrayList<Integer>> zmap = new HashMap<String, ArrayList<Integer>>();
				for (Iterator<Integer> k = remainSet.iterator(); k.hasNext(); ) {
					Integer ke = k.next();
					if (!data.get(ke, ea.getKey()).equals(v)) {
						// 指定された属性が、指定された属性値でない場合はカウントしない
						continue;
					}
					String zvalue = data.get(ke, data.lastIindex());
					if (decision.equals(zvalue)) {
						++acount;
					}
					ArrayList<Integer> alist = zmap.get(zvalue);
					if (alist == null) {
						alist = new ArrayList<Integer>();
						zmap.put(zvalue, alist);
					}
					alist.add(ke);
					++acountAll;
				}
				if (zmap.size() > 0) {
					pmap.put(v, zmap);
				}
				double score = calcEvaluator(data, acount, acountAll, decisionCount, countMap(beforeMap));
				if (score > eval) {
					eval = score;
					pair = new RuleIfThen.KeyValuePair(ea.getKey(), v);
				}
			}
			if (pmap.size() > 0) {
				map.put(ea.getKey(), pmap);
			}
		}
		return pair;
	}

	// 要素数が最も大きい key を返す
	private String retrieveMostLikely(Map<String, ArrayList<Integer>> map) {
		String r = null;
		int count = -1;
		if (map == null) {
			return null;
		}
		for (Iterator<Map.Entry<String, ArrayList<Integer>>> i = map.entrySet().iterator(); i.hasNext(); ) {
			Map.Entry<String, ArrayList<Integer>> e = i.next();
			int c = e.getValue().size();
			if (c > count) {
				count = c;
				r = e.getKey();
			}
		}
		return r;
	}

	/**
	 * マップに所属するライン番号を合算して返す
	 *
	 * @param map
	 * @return
	 */
	private ArrayList<Integer> createRemainSet(Map<String, ArrayList<Integer>> map) {
		ArrayList<Integer> list = new ArrayList<Integer>();
		if (map != null) {
			for (Iterator<Map.Entry<String, ArrayList<Integer>>> i = map.entrySet().iterator(); i.hasNext(); ) {
				list.addAll(i.next().getValue());
			}
		}
		return list;
	}

	public void analyze(GenericData data) {
		// 全体集合を求める
		ArrayList<Integer> remainSet = createWholeSet(data);
		Map<String, ArrayList<Integer>> beforeMap = classifyResult(data, remainSet);
		concept = new ArrayList<RuleIfThen>();

		// 取り得る属性値の全体集合を求める
		Map<String, HashSet<String>> attributes = data.cloneAttributes();
		// ただし、最後の属性は判定に使うので、ルールを求めるときには使用しない
		attributes.remove(data.attributeNameList.get(data.attributeNameList.size() - 1));

		while (remainSet.size() > 0) {
			Set<RuleIfThen.KeyValuePair> pairSet = new HashSet<RuleIfThen.KeyValuePair>();
			ArrayList<Integer> copiedSet = new ArrayList<Integer>(remainSet);
			Map<String, ArrayList<Integer>> copiedMap = new HashMap<String, ArrayList<Integer>>(beforeMap);
			String zvalue = null;
			int zCount = Integer.MAX_VALUE;
			// 最も要素数の少ない判定値を算出する
			for (Iterator<Map.Entry<String, ArrayList<Integer>>> ie = beforeMap.entrySet().iterator(); ie.hasNext(); ) {
				Map.Entry<String, ArrayList<Integer>> e = ie.next();
				int c = e.getValue().size();
				if (zCount > c) {
					zCount = c;
					zvalue = e.getKey();
				}
			}
			if (zCount == countMap(beforeMap)) {
				// 単一の判定値しか含まないので分類は終了
				RuleIfThen rule = new RuleIfThen(zvalue, zCount, zCount);
				concept.add(rule);
				break;
			}

			while (true) {
				Map<String, Map<String, Map<String, ArrayList<Integer>>>> afterMap = new HashMap<String, Map<String, Map<String, ArrayList<Integer>>>>();

				RuleIfThen.KeyValuePair pair = getMaxPair(afterMap, data, copiedSet, attributes, copiedMap, zvalue);
				boolean bFound = true;
				if (pair != null) {
					pairSet.add(pair);
					copiedMap = afterMap.get(pair.key).get(pair.value);
					copiedSet = createRemainSet(copiedMap);
					attributes.get(pair.key).remove(pair.value);
					if (attributes.get(pair.key).size() == 0) {
						attributes.remove(pair.key);
					}
					// 条件を追加した結果、判定値が1つしか残らなかった場合はこれ以上条件を追加しない
					if (copiedMap.keySet().size() > 1) {
						bFound = false;
					}
				}

				if (bFound) {
					String sr = retrieveMostLikely(copiedMap);
					if (sr != null) {
						int count = 0;
						if (copiedMap.get(sr) != null) {
							count = copiedMap.get(sr).size();
						}
						RuleIfThen rule = new RuleIfThen(sr, count, countMap(copiedMap));
						rule.rules = pairSet;
						concept.add(rule);
					}
					remainSet.removeAll(copiedSet);
					subtract(beforeMap, copiedMap);
					break;
				}
			}
		}
	}
}

Ripper は If <A=a & B=b ...> Then Z = z という条件式を
多段に重ねて知識を表現する。

Weka 付属の weather.nominal.arff ファイルを食わせてみると、
実行効率は悪いけど、一応動いているみたい。

result:
if (humidity=high) and (outlook=sunny):no(3/3)
if (outlook=rainy) and (windy=TRUE):no(2/2)
if :yes(9/9)

NullPointerException に悩まされながらここまで来るのに3日かかった。
ふぅ~。