読者です 読者をやめる 読者になる 読者になる

徒然なる日々を送るソフトウェアデベロッパーの記録(2)

技術上思ったことや感じたことを気ままに記録していくブログです。さくらから移設しました。

隠れ層のないニューラルネットを実装してみた

今更感満載ですが、勉強がてら2層のニューラルネットJava で実装してみた。
教科書は例によってこの本です。

フリーソフトではじめる機械学習入門

フリーソフトではじめる機械学習入門

底本にした第1版 p.112 6.12 式は第2項の符号が間違っています。
次のページにあるアルゴリズム中の式が正解。
また、δの計算に使用する重みは自分より1つ下ものを使用します。
式で書くと
{
\delta_p = \sum_k{w_{pk} \delta_k }
}
ただし、 w_{pk}は自分より1つ下のノードで使用される線形結合
の係数。

初期値については記述がないが、区間 (-0.5, 0.5) のランダム
値を使用することにした。(議論の余地あり)
出力層のノード数はクラス数とする。
第1層のノード数は特徴数を取るのがよさそうだ。
識別器の出力は、最も大きい出現確率を返すクラス値とする。

ThreeLayerNeuralNet.java
package com.minosys.aplication;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * 3層ニューラルネットワーク
 * 中間層のノード数は class 数固定にしてある
 * パラメータの決定に誤差逆伝播法を使用する
 *
 * @author minosys
 *
 */
public class ThreeLayerNeuralNet extends AbstractClassifier implements
		Classifier {
	private final static double CONVERGE_ERROR = 1e-4;
	private final static double FACTOR = 1.0 - 1e-4;

	// 第1層のキーは番号、最下層のキーはクラス値とする
	List<Map<String, Perceptron>> concept;
	double eta = 0.3;

	// concept を初期化して返す
	private List<Map<String, Perceptron>> createConcept(int nattr, Set<String> set) {
		List<Map<String, Perceptron>> c = new ArrayList<Map<String, Perceptron>>();
		Map<String, Perceptron> map = new HashMap<String, Perceptron>();

		// 第1層; 定義は特徴数とする
		for (int n = 0; n < nattr; ++n) {
			map.put(Integer.toString(n), new Perceptron(nattr));
		}
		c.add(map);

		// 第2層; 定義は class 数とする
		map = new HashMap<String, Perceptron>();
		for (Iterator<String> s = set.iterator(); s.hasNext(); ) {
			map.put(s.next(), new Perceptron(c.get(0).size()));
		}
		c.add(map);
		return c;
	}

	// 文字列を数値に変換する; 最後のデータはクラス値なので含めない
	private List<Double> convert(List<String> src) {
		List<Double> target = new ArrayList<Double>();
		for (int i = 0; i < src.size() - 1; ++i) {
			target.add(Double.valueOf(src.get(i)));
		}
		return target;
	}

	// 誤差を格納する領域を初期化して返す
	private List<Map<String, Double>> initializeDelta() {
		List<Map<String, Double>> target = new ArrayList<Map<String, Double>>();
		for (Iterator<Map<String, Perceptron>> i = concept.iterator(); i.hasNext(); ) {
			Map<String, Perceptron> map = i.next();
			Map<String, Double> pmap = new HashMap<String, Double>();
			for (Iterator<String> j = map.keySet().iterator(); j.hasNext(); ) {
				pmap.put(j.next(), 0.0);
			}
			target.add(pmap);
		}
		return target;
	}

	// 各 perceptron の出力を計算する
	private void calcOutput(List<Map<String, Double>> map, List<Double> x) {
		if (map.size() < 2) {
			map.add(new HashMap<String, Double>());
			map.add(new HashMap<String, Double>());
		}

		// 第1層の出力
		List<Double> o1 = new ArrayList<Double>();
		Map<String, Double> pmap = map.get(0);
		for (Iterator<Map.Entry<String, Perceptron>> i = concept.get(0).entrySet().iterator(); i.hasNext(); ) {
			Map.Entry<String, Perceptron> e = i.next();
			double output = e.getValue().calc(x);
			o1.add(output);
			pmap.put(e.getKey(), output);
		}
		map.add(pmap);

		// 第2層の出力
		pmap = map.get(1);
		for (Iterator<Map.Entry<String, Perceptron>> i = concept.get(1).entrySet().iterator(); i.hasNext(); ) {
			Map.Entry<String, Perceptron> e = i.next();
			double output = e.getValue().calc(o1);
			pmap.put(e.getKey(), output);
		}
		map.add(pmap);
	}

	@Override
	public void analyze(GenericData data) {
		// TODO 自動生成されたメソッド・スタブ
		Set<String> classSet = data.attributes.get(data.attributeNameList.get(data.lastIindex()));
		int layerSize = classSet.size();
		concept = createConcept(data.attributeNameList.size() - 1, classSet);

		// 入力データ列をランダム化する
		List<Integer> seq = new ArrayList<Integer>();
		for (int i = 0; i < data.data.size(); ++i) {
			seq.add(i);
		}
		Collections.shuffle(seq);

		// 各データについて
		List<Map<String, Double>> output = new ArrayList<Map<String, Double>>();
		List<Map<String, Double>> delta = initializeDelta();

		for (int i = 0; i < data.data.size() && eta > CONVERGE_ERROR; ++i, eta = eta * FACTOR) {
			List<String> d = data.data.get(seq.get(i));
			List<Double> x = convert(d);

			while (true) {
				// NeuralNet の暫定出力を求める
				calcOutput(output, x);

				// 最下層の perceptron のパラメータを求める
				for (Iterator<String> j = classSet.iterator(); j.hasNext(); ) {
					String s = j.next();
					double y = (s.equals(d.get(d.size() - 1)))?1.0:0.0;
					double o_k = output.get(1).get(s);
					double err = o_k * (1.0 - o_k) * (y - o_k);
					delta.get(1).put(s, err);
				}

				// 第1層の perceptron のパラメータを求める
				for (Iterator<String> j = delta.get(0).keySet().iterator(); j.hasNext(); ) {
					String s = j.next();
					double err = 0.0;
					for (Iterator<Map.Entry<String, Double>> k = delta.get(1).entrySet().iterator(); k.hasNext(); ) {
						Map.Entry<String, Double> e = k.next();
						err = err + delta.get(1).get(e.getKey()) * concept.get(1).get(e.getKey()).w.get(Integer.valueOf(s));
					}
					double o_k = output.get(0).get(s);
					err = o_k * (1.0 - o_k) * err;
					delta.get(0).put(s, err);
				}

				// 各層の修正値を求める
				double maxerr = 0.0;
				// 第1層
				for (Iterator<String> j = concept.get(0).keySet().iterator(); j.hasNext(); ) {
					String si = j.next();
					Perceptron pc = concept.get(0).get(si);
					for (int w = 0; w < pc.w.size(); ++w) {
						double err = +eta * delta.get(0).get(si) * ((w == pc.w.size() - 1)?1.0:x.get(w));
						maxerr = Math.max(maxerr, Math.abs(err));
						err = pc.w.get(w) + err;
						pc.w.set(w, err);
					}
				}
				// 最下層
				for (Iterator<String> j = concept.get(1).keySet().iterator(); j.hasNext(); ) {
					String s = j.next();
					Perceptron pc = concept.get(1).get(s);
					for (int w = 0; w < pc.w.size(); ++w) {
						double err = +eta * delta.get(1).get(s) * ((w == pc.w.size() - 1)?1.0:output.get(0).get(Integer.toString(w)));
						maxerr = Math.max(maxerr, Math.abs(err));
						err = pc.w.get(w) + err;
						pc.w.set(w, err);
					}
				}
				if (maxerr < CONVERGE_ERROR) {
					System.out.print("*");
					break;
				}
			}
		}
	}

	@Override
	public ValueWithProbability decide(List<String> slist,
			Map<String, Integer> amap) {
		// TODO 自動生成されたメソッド・スタブ
		List<Map<String, Double>> maplist = new ArrayList<Map<String, Double>>();
		List<Double> x = convert(slist);

		// 各 perceptron の出力を計算
		calcOutput(maplist, x);

		// 最下層で最も高い出力値のクラス値を識別器の出力とする
		String symbol = null;
		double prob = 0.0;
		for (Iterator<Map.Entry<String, Double>> i = maplist.get(1).entrySet().iterator(); i.hasNext(); ) {
			Map.Entry<String, Double> e = i.next();
			if (e.getValue() > prob) {
				symbol = e.getKey();
				prob = e.getValue();
			}
		}
		ValueWithProbability vp = new ValueWithProbability(symbol, prob);
		return vp;
	}
}
Perceptron.java
package com.minosys.aplication;

import java.util.ArrayList;
import java.util.List;

public class Perceptron {
	public List<Double> w;

	private double generator() {
		return (Math.random() - 1.0) * 0.5;
	}
	public Perceptron(int sz) {
		w = new ArrayList<Double>();
		for (int i = 0; i < sz; ++i) {
			w.add(generator());
		}
		w.add(generator());
	}

	private double sigmoid(double x) {
		return 1.0 / (1.0 + Math.exp(-x));
	}

	public double calc(List<Double> x) {
		double pr = 0.0;
		for (int i = 0; i < x.size(); ++i) {
			pr = pr + x.get(i) * w.get(i);
		}
		pr = pr + w.get(w.size() - 1);
		return sigmoid(pr);
	}

	public String toString(List<String> names) {
		StringBuffer sb = new StringBuffer();
		sb.append(String.format("%.4g", w.get(w.size() - 1)));
		for (int i = 0; i < w.size() - 1; ++i) {
			if (w.get(i) >= 0.0) {
				sb.append("+");
			} else {
				sb.append("-");
			}
			sb.append(String.format("%.4g", Math.abs(w.get(i)))).append("*[").append(names.get(i)).append("]");
		}
		return sb.toString();
	}

	@Override
	public String toString() {
		StringBuffer sb = new StringBuffer();
		sb.append(String.format("%.4g", w.get(w.size() - 1)));
		for (int i = 0; i < w.size() - 1; ++i) {
			if (w.get(i) >= 0.0) {
				sb.append("+");
			} else {
				sb.append("-");
			}
			sb.append(String.format("%.4g", Math.abs(w.get(i)))).append("*[").append(i).append("]");
		}
		return sb.toString();
	}
}

最急降下法がなかなか収束しなくて苦労した。
誤差関数が最小を通り過ぎていそうだったら2分検索法で最小値を
探すことも考えたが、結局シンプルな方式にした。
当初シンプルな方法では収束しないのではないかと思ったが、
時間はかかるものの、意外に収束することが分かった。