隠れ層のないニューラルネットを実装してみた
今更感満載ですが、勉強がてら2層のニューラルネットを Java で実装してみた。
教科書は例によってこの本です。
- 作者: 荒木雅弘
- 出版社/メーカー: 森北出版
- 発売日: 2014/03/29
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (4件) を見る
次のページにあるアルゴリズム中の式が正解。
また、δの計算に使用する重みは自分より1つ下ものを使用します。
式で書くと
ただし、は自分より1つ下のノードで使用される線形結合
の係数。
初期値については記述がないが、区間 (-0.5, 0.5) のランダム
値を使用することにした。(議論の余地あり)
出力層のノード数はクラス数とする。
第1層のノード数は特徴数を取るのがよさそうだ。
識別器の出力は、最も大きい出現確率を返すクラス値とする。
ThreeLayerNeuralNet.java
package com.minosys.aplication; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; /** * 3層ニューラルネットワーク * 中間層のノード数は class 数固定にしてある * パラメータの決定に誤差逆伝播法を使用する * * @author minosys * */ public class ThreeLayerNeuralNet extends AbstractClassifier implements Classifier { private final static double CONVERGE_ERROR = 1e-4; private final static double FACTOR = 1.0 - 1e-4; // 第1層のキーは番号、最下層のキーはクラス値とする List<Map<String, Perceptron>> concept; double eta = 0.3; // concept を初期化して返す private List<Map<String, Perceptron>> createConcept(int nattr, Set<String> set) { List<Map<String, Perceptron>> c = new ArrayList<Map<String, Perceptron>>(); Map<String, Perceptron> map = new HashMap<String, Perceptron>(); // 第1層; 定義は特徴数とする for (int n = 0; n < nattr; ++n) { map.put(Integer.toString(n), new Perceptron(nattr)); } c.add(map); // 第2層; 定義は class 数とする map = new HashMap<String, Perceptron>(); for (Iterator<String> s = set.iterator(); s.hasNext(); ) { map.put(s.next(), new Perceptron(c.get(0).size())); } c.add(map); return c; } // 文字列を数値に変換する; 最後のデータはクラス値なので含めない private List<Double> convert(List<String> src) { List<Double> target = new ArrayList<Double>(); for (int i = 0; i < src.size() - 1; ++i) { target.add(Double.valueOf(src.get(i))); } return target; } // 誤差を格納する領域を初期化して返す private List<Map<String, Double>> initializeDelta() { List<Map<String, Double>> target = new ArrayList<Map<String, Double>>(); for (Iterator<Map<String, Perceptron>> i = concept.iterator(); i.hasNext(); ) { Map<String, Perceptron> map = i.next(); Map<String, Double> pmap = new HashMap<String, Double>(); for (Iterator<String> j = map.keySet().iterator(); j.hasNext(); ) { pmap.put(j.next(), 0.0); } target.add(pmap); } return target; } // 各 perceptron の出力を計算する private void calcOutput(List<Map<String, Double>> map, List<Double> x) { if (map.size() < 2) { map.add(new HashMap<String, Double>()); map.add(new HashMap<String, Double>()); } // 第1層の出力 List<Double> o1 = new ArrayList<Double>(); Map<String, Double> pmap = map.get(0); for (Iterator<Map.Entry<String, Perceptron>> i = concept.get(0).entrySet().iterator(); i.hasNext(); ) { Map.Entry<String, Perceptron> e = i.next(); double output = e.getValue().calc(x); o1.add(output); pmap.put(e.getKey(), output); } map.add(pmap); // 第2層の出力 pmap = map.get(1); for (Iterator<Map.Entry<String, Perceptron>> i = concept.get(1).entrySet().iterator(); i.hasNext(); ) { Map.Entry<String, Perceptron> e = i.next(); double output = e.getValue().calc(o1); pmap.put(e.getKey(), output); } map.add(pmap); } @Override public void analyze(GenericData data) { // TODO 自動生成されたメソッド・スタブ Set<String> classSet = data.attributes.get(data.attributeNameList.get(data.lastIindex())); int layerSize = classSet.size(); concept = createConcept(data.attributeNameList.size() - 1, classSet); // 入力データ列をランダム化する List<Integer> seq = new ArrayList<Integer>(); for (int i = 0; i < data.data.size(); ++i) { seq.add(i); } Collections.shuffle(seq); // 各データについて List<Map<String, Double>> output = new ArrayList<Map<String, Double>>(); List<Map<String, Double>> delta = initializeDelta(); for (int i = 0; i < data.data.size() && eta > CONVERGE_ERROR; ++i, eta = eta * FACTOR) { List<String> d = data.data.get(seq.get(i)); List<Double> x = convert(d); while (true) { // NeuralNet の暫定出力を求める calcOutput(output, x); // 最下層の perceptron のパラメータを求める for (Iterator<String> j = classSet.iterator(); j.hasNext(); ) { String s = j.next(); double y = (s.equals(d.get(d.size() - 1)))?1.0:0.0; double o_k = output.get(1).get(s); double err = o_k * (1.0 - o_k) * (y - o_k); delta.get(1).put(s, err); } // 第1層の perceptron のパラメータを求める for (Iterator<String> j = delta.get(0).keySet().iterator(); j.hasNext(); ) { String s = j.next(); double err = 0.0; for (Iterator<Map.Entry<String, Double>> k = delta.get(1).entrySet().iterator(); k.hasNext(); ) { Map.Entry<String, Double> e = k.next(); err = err + delta.get(1).get(e.getKey()) * concept.get(1).get(e.getKey()).w.get(Integer.valueOf(s)); } double o_k = output.get(0).get(s); err = o_k * (1.0 - o_k) * err; delta.get(0).put(s, err); } // 各層の修正値を求める double maxerr = 0.0; // 第1層 for (Iterator<String> j = concept.get(0).keySet().iterator(); j.hasNext(); ) { String si = j.next(); Perceptron pc = concept.get(0).get(si); for (int w = 0; w < pc.w.size(); ++w) { double err = +eta * delta.get(0).get(si) * ((w == pc.w.size() - 1)?1.0:x.get(w)); maxerr = Math.max(maxerr, Math.abs(err)); err = pc.w.get(w) + err; pc.w.set(w, err); } } // 最下層 for (Iterator<String> j = concept.get(1).keySet().iterator(); j.hasNext(); ) { String s = j.next(); Perceptron pc = concept.get(1).get(s); for (int w = 0; w < pc.w.size(); ++w) { double err = +eta * delta.get(1).get(s) * ((w == pc.w.size() - 1)?1.0:output.get(0).get(Integer.toString(w))); maxerr = Math.max(maxerr, Math.abs(err)); err = pc.w.get(w) + err; pc.w.set(w, err); } } if (maxerr < CONVERGE_ERROR) { System.out.print("*"); break; } } } } @Override public ValueWithProbability decide(List<String> slist, Map<String, Integer> amap) { // TODO 自動生成されたメソッド・スタブ List<Map<String, Double>> maplist = new ArrayList<Map<String, Double>>(); List<Double> x = convert(slist); // 各 perceptron の出力を計算 calcOutput(maplist, x); // 最下層で最も高い出力値のクラス値を識別器の出力とする String symbol = null; double prob = 0.0; for (Iterator<Map.Entry<String, Double>> i = maplist.get(1).entrySet().iterator(); i.hasNext(); ) { Map.Entry<String, Double> e = i.next(); if (e.getValue() > prob) { symbol = e.getKey(); prob = e.getValue(); } } ValueWithProbability vp = new ValueWithProbability(symbol, prob); return vp; } }
Perceptron.java
package com.minosys.aplication; import java.util.ArrayList; import java.util.List; public class Perceptron { public List<Double> w; private double generator() { return (Math.random() - 1.0) * 0.5; } public Perceptron(int sz) { w = new ArrayList<Double>(); for (int i = 0; i < sz; ++i) { w.add(generator()); } w.add(generator()); } private double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); } public double calc(List<Double> x) { double pr = 0.0; for (int i = 0; i < x.size(); ++i) { pr = pr + x.get(i) * w.get(i); } pr = pr + w.get(w.size() - 1); return sigmoid(pr); } public String toString(List<String> names) { StringBuffer sb = new StringBuffer(); sb.append(String.format("%.4g", w.get(w.size() - 1))); for (int i = 0; i < w.size() - 1; ++i) { if (w.get(i) >= 0.0) { sb.append("+"); } else { sb.append("-"); } sb.append(String.format("%.4g", Math.abs(w.get(i)))).append("*[").append(names.get(i)).append("]"); } return sb.toString(); } @Override public String toString() { StringBuffer sb = new StringBuffer(); sb.append(String.format("%.4g", w.get(w.size() - 1))); for (int i = 0; i < w.size() - 1; ++i) { if (w.get(i) >= 0.0) { sb.append("+"); } else { sb.append("-"); } sb.append(String.format("%.4g", Math.abs(w.get(i)))).append("*[").append(i).append("]"); } return sb.toString(); } }
最急降下法がなかなか収束しなくて苦労した。
誤差関数が最小を通り過ぎていそうだったら2分検索法で最小値を
探すことも考えたが、結局シンプルな方式にした。
当初シンプルな方法では収束しないのではないかと思ったが、
時間はかかるものの、意外に収束することが分かった。