徒然なる日々を送るソフトウェアデベロッパーの記録(2)

技術上思ったことや感じたことを気ままに記録していくブログです。さくらから移設しました。

今更、JCuda で AutoEncoder を実装してみた

JCuda の使い勝手を研究するため、AutoEncoder のさわりの部分を実装してみました。
まだバグはいそうですが...
CUDA Driver インタフェースを使うため、初期化と kernel 呼び出しが多少面倒ですが、
メソッドに押し込めてしまうことにします。

まず、AutoEncoder.java のコード。

import jcuda.Sizeof;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.JCudaDriver;

import static jcuda.driver.JCudaDriver.*;

import java.io.IOException;

import jcuda.Pointer;

// AutoEncoder クラス
public class AutoEncoder {
	private static final int NTHREAD = 16;	// GPU スレッド数; 実際にはこの2乗ある
	private static final int SHARED_SIZE = NTHREAD * (NTHREAD + 1) * Sizeof.FLOAT;	// shared memory サイズ
	private static final String  PTXFILENAME = "AutoEncoder.cu";
	
	int m;	// 中間ノード数
	int n;	// 入力/出力ノード数
	CUdeviceptr devX, devY, devZ;	// 入力値、中間値、出力値
	CUdeviceptr devB1, devB2;		// 定数ノード
	CUdeviceptr devW;	// パーセプトロンの結合係数
	CUdeviceptr devWd, devB1d, devB1t, devB2d, devB2t;	// 逆誤差伝播用
	String ptxFileName;	// CUDA コードファイル名
	CUfunction fClear, fClearW;
	CUfunction fEncode, fDecode, fShaper, fCalcB2d, fCalcB1d, fCalcWd, fForward, fForwardW;	// エントリポイント
	
	// CUDA カーネルの呼び出し
	protected static void callKernel(CUfunction func, Pointer param, int num, int sharedSize) {
		int block = (num + NTHREAD - 1) / NTHREAD;
		cuLaunchKernel(func,
			block, 1, 1,	// block
			NTHREAD, 1, 1,	// thread
			sharedSize, null,	// shared size, stream
			param, null		// kernel parameters
				);
		cuCtxSynchronize();
	}
	
	protected static void callKernel(CUfunction func, Pointer param, int num) {
		callKernel(func, param, num, 0);
	}
	
	// CUDA カーネル呼び出し(二次元版)
	protected static void callKernel2(CUfunction func, Pointer param, int numX, int numY) {
		int blockX = (numX + NTHREAD - 1) / NTHREAD;
		int blockY = (numY + NTHREAD - 1) / NTHREAD;
		cuLaunchKernel(func,
				blockX, blockY, 1,		// block
				NTHREAD, NTHREAD, 1,	// thread
				SHARED_SIZE, null,		// shared size, stream
				param, null				// kernel parameters
				);
		cuCtxSynchronize();
	}
	
	// コンストラクタ
	public AutoEncoder(int n, int m) throws IOException {
		this.m = m;
		this.n = n;
		JCudaDriver.setExceptionsEnabled(true);
		ptxFileName = Utils.preparePtxFile(PTXFILENAME);
		
		// CUDA 初期化
		cuInit(0);
		CUcontext pctx = new CUcontext();
		CUdevice dev = new CUdevice();
		cuDeviceGet(dev, 0);
		cuCtxCreate(pctx, 0, dev);
		
		// PTX ファイルをロードする
		CUmodule module = new CUmodule();
		cuModuleLoad(module, ptxFileName);
		
		// 関数のエントリポイントを取得する
		fClear = new CUfunction();
		cuModuleGetFunction(fClear, module, "clear");
		fClearW = new CUfunction();
		cuModuleGetFunction(fClearW, module, "clearW");
		fEncode = new CUfunction();
		cuModuleGetFunction(fEncode, module, "encode");
		fDecode = new CUfunction();
		cuModuleGetFunction(fDecode, module, "decode");
		fShaper = new CUfunction();
		cuModuleGetFunction(fShaper, module, "shaper");
		fCalcB2d = new CUfunction();
		cuModuleGetFunction(fCalcB2d, module, "calcB2d");
		fCalcB1d = new CUfunction();
		cuModuleGetFunction(fCalcB1d, module, "calcB1d");
		fCalcWd = new CUfunction();
		cuModuleGetFunction(fCalcWd, module, "calcWd");
		fForward = new CUfunction();
		cuModuleGetFunction(fForward, module, "forward");
		fForwardW = new CUfunction();
		cuModuleGetFunction(fForwardW, module, "forwardW");
		
		devX = new CUdeviceptr();
		cuMemAlloc(devX, n * Sizeof.FLOAT);
		devY = new CUdeviceptr();
		cuMemAlloc(devY, m * Sizeof.FLOAT);
		devZ = new CUdeviceptr();
		cuMemAlloc(devZ, n * Sizeof.FLOAT);
		devB1 = new CUdeviceptr();
		cuMemAlloc(devB1, m * Sizeof.FLOAT);
		devB2 = new CUdeviceptr();
		cuMemAlloc(devB2, n * Sizeof.FLOAT);
		devB1d = new CUdeviceptr();
		cuMemAlloc(devB1d, m * Sizeof.FLOAT);
		devB1t = new CUdeviceptr();
		cuMemAlloc(devB1t, m * Sizeof.FLOAT);
		devB2d = new CUdeviceptr();
		cuMemAlloc(devB2d, n * Sizeof.FLOAT);
		devB2t = new CUdeviceptr();
		cuMemAlloc(devB2t, n * Sizeof.FLOAT);
		devW = new CUdeviceptr();
		cuMemAlloc(devW, m * Sizeof.POINTER);
		devWd = new CUdeviceptr();
		cuMemAlloc(devWd, m * Sizeof.POINTER);
		CUdeviceptr[] wtmp = new CUdeviceptr[m];
		for (int i = 0; i < m; ++i) {
			wtmp[i] = new CUdeviceptr();
			cuMemAlloc(wtmp[i], n * Sizeof.FLOAT);
			
			// 係数 w をランダムに初期化する
			float[] winit = new float[n];
			for (int j = 0; j < n; ++j) {
				winit[j] = (float)((Math.random() - 0.5) * 2.0 * 0.01);
			}
			cuMemcpyHtoD(wtmp[i], Pointer.to(winit), n * Sizeof.FLOAT);
		}
		cuMemcpyHtoD(devW, Pointer.to(wtmp), m * Sizeof.POINTER);
		CUdeviceptr[] wdtmp = new CUdeviceptr[m];
		for (int i = 0; i < m; ++i) {
			wdtmp[i] = new CUdeviceptr();
			cuMemAlloc(wdtmp[i], n * Sizeof.FLOAT);
		}
		cuMemcpyHtoD(devWd, Pointer.to(wdtmp), m * Sizeof.POINTER);
		initBias();
	}
	
	@Override
	public void finalize() {
		cuMemFree(devX);
		cuMemFree(devY);
		cuMemFree(devZ);
		cuMemFree(devB1);
		cuMemFree(devB2);
		cuMemFree(devB1t);
		cuMemFree(devB1d);
		cuMemFree(devB2t);
		cuMemFree(devB2d);
		CUdeviceptr[] ptr = new CUdeviceptr[m];
		cuMemcpyDtoH(Pointer.to(ptr), devW, m * Sizeof.POINTER);
		for (int i = 0; i < m; ++i) {
			cuMemFree(ptr[i]);
		}
		cuMemFree(devW);
		cuMemcpyDtoH(Pointer.to(ptr), devWd, m * Sizeof.POINTER);
		for (int i = 0; i < m; ++i) {
			cuMemFree(ptr[i]);
		}
		cuMemFree(devWd);
	}
	
	// CUDA スレッドの終了を待つ
	public void sync() {
		cuCtxSynchronize();
	}
	
	// devX に入力データをセットする
	public void setInput(float[] in) {
		assert(in.length == n);
		cuMemcpyHtoD(devX, Pointer.to(in), n * Sizeof.FLOAT);
	}
	
	// devZ から出力値を取得する
	public float[] getOutput() {
		float[] out = new float[n];
		cuMemcpyDtoH(Pointer.to(out), devZ, n * Sizeof.FLOAT);
		return out;
	}
	
	// devY から中間値を取得する
	public float[] getIntermediate() {
		float[] im = new float[m];
		cuMemcpyDtoH(Pointer.to(im), devY, m * Sizeof.FLOAT);
		return im;
	}
	
	// 入力データを中間値にエンコードする
	public void encode() {
		// 中間値を 0 クリアする
		Pointer p = Pointer.to(Pointer.to(devY), Pointer.to(new int[]{m}));
		callKernel(fClear, p, m);
		
		// encode を実行する
		p = Pointer.to(Pointer.to(devW), Pointer.to(devX), Pointer.to(new int[]{n}), Pointer.to(new int[]{m}), Pointer.to(devY));
		callKernel2(fEncode, p, n, m);
		
		// 定数ノードの影響および sigmoid 関数を通す
		p = Pointer.to(Pointer.to(devB1), Pointer.to(new int[]{m}), Pointer.to(devY));
		callKernel(fShaper, p, m);
	}
	
	// 中間値を出力値にデコードする
	public void decode() {
		// 出力値を 0 クリアする
		Pointer p = Pointer.to(Pointer.to(devZ), Pointer.to(new int[]{n}));
		callKernel(fClear, p, n);
		
		// decode を実行する
		p = Pointer.to(Pointer.to(devW), Pointer.to(devY), Pointer.to(new int[]{n}), Pointer.to(new int[]{m}), Pointer.to(devZ));
		callKernel2(fDecode, p, n, m);
		
		// 定数ノードの影響および sigmoid 関数を通す
		p = Pointer.to(Pointer.to(devB2), Pointer.to(new int[]{n}), Pointer.to(devZ));
		callKernel(fShaper, p, n);
	}
	
	// 誤差領域を初期化する
	public void initDiff() {
		Pointer p = Pointer.to(Pointer.to(devB1d), Pointer.to(new int[]{m}));
		callKernel(fClear, p, m);
		p = Pointer.to(Pointer.to(devB2d), Pointer.to(new int[]{n}));
		callKernel(fClear, p, n);
		p = Pointer.to(Pointer.to(devWd), Pointer.to(new int[]{n}), Pointer.to(new int[]{m}));
		callKernel2(fClearW, p, n, m);
	}
	
	// 逆誤差散乱
	public void updateDiff() {
		// B1t の 0 クリア
		Pointer p = Pointer.to(Pointer.to(devB1t), Pointer.to(new int[]{m}));
		callKernel(fClear, p, m);
		
		// B2t, B2d のアップデート
		p = Pointer.to(Pointer.to(devX), Pointer.to(devZ), Pointer.to(new int[]{n}), Pointer.to(devB2t), Pointer.to(devB2d));
		callKernel(fCalcB2d, p, n);
		
		// B1t, B1d のアップデート
		p = Pointer.to(Pointer.to(devW), Pointer.to(devY), Pointer.to(devB2t), Pointer.to(new int[]{n}), Pointer.to(new int[]{m}),
				Pointer.to(devB1t), Pointer.to(devB1d));
		callKernel2(fCalcB1d, p, n, m);
		
		// Wd のアップデート
		p = Pointer.to(Pointer.to(devB1t), Pointer.to(devB2t), Pointer.to(devX), Pointer.to(devY),
				Pointer.to(new int[]{n}), Pointer.to(new int[]{m}), Pointer.to(devWd));
		callKernel2(fCalcWd, p, n, m);
	}
	
	// 定数バイアスの初期化
	public void initBias() {
		Pointer p = Pointer.to(Pointer.to(devB1), Pointer.to(new int[]{m}));
		callKernel(fClear, p, m);
		p = Pointer.to(Pointer.to(devB2), Pointer.to(new int[]{n}));
		callKernel(fClear, p, n);
	}
	
	// 定数バイアス、結合係数の更新
	public void updateBias(float rate) {
		// b2 のアップデート
		Pointer p = Pointer.to(Pointer.to(devB2d), Pointer.to(new float[]{rate}), Pointer.to(new int[]{n}), Pointer.to(devB2));
		callKernel(fForward, p, n);
		
		// b1 のアップデート
		p = Pointer.to(Pointer.to(devB1d), Pointer.to(new float[]{rate}), Pointer.to(new int[]{m}), Pointer.to(devB1));
		callKernel(fForward, p, m);
		
		// w のアップデート
		p = Pointer.to(Pointer.to(devWd), Pointer.to(new float[]{rate}), Pointer.to(new int[]{n}), Pointer.to(new int[]{m}),
				Pointer.to(devW));
		callKernel2(fForwardW, p, n, m);
	}
	
	// トレーニング
	// pCnt: パターン数
	// tCnt: トレーニング回数
	// bSize: トレーニングバッチサイズ
	// rate: 学習率
	public void training(float[][] data, int pCnt, int tCnt, int bSize, float rate) {
		int count = 0;
		for (int i = 0; i < pCnt * tCnt / bSize; ++i) {
			// 誤差を初期化する
			initDiff();
			
			// バッチループ
			for (int j = 0; j < bSize; ++j) {
				int index = (j + i * bSize) % pCnt;
				setInput(data[index]);
				encode();
				decode();
				updateDiff();

				if (count % 100 == 100 - 1) {
					System.out.print(".");
					if (count % 10000 == 10000 - 1) {
						System.out.println(count + 1);
					}
				}
				++count;
			}
			// 結合係数と定数ノードを更新する
			updateBias(rate);
		}
		System.out.println("trained");
	}
}

気をつけるべき点は、二重配列でしょうか。
float[][] は float ** に変換されますが、デバイスメモリでは
float * の配列として定義する必要があります。

続いて CUDA コード。

// 以下、n: 入力値/出力値の個数, m: 中間ノードの個数とする。
const int NTHREAD = 16;

#define IDX(x, y) (x + y * NTHREAD)
#define IDXX(x) (x + NTHREAD * NTHREAD)

// sigmoid 関数
__device__
float sigmoid(float x) {
  return 1.0f / (1.0f + exp(-x));
}

// バッファを 0 クリアする
extern "C"
__global__ void clear(float *buf, int n) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i < n) {
    buf[i] = 0.0f;
  }
}

// W バッファを 0 クリアする
extern "C"
__global__ void clearW(float **buf, int n, int m) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  int j = threadIdx.y + blockIdx.y * blockDim.y;
  
  if (i < n && j < m) {
    buf[j][i] = 0.0f;
  }
}

// 入力データのエンコード
// 最初に out を 0 クリアしておくこと
// 出力値を得るには更に shape を通す必要がある
extern "C"
__global__ void encode(float **w, float *in, int n, int m, float *out) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  int j = threadIdx.y + blockIdx.y * blockDim.y;
  int i2 = blockDim.x / 2;
  extern __shared__ float s[];
  
  // 入力領域にキャッシュする
  if (threadIdx.y == 0 && i < n) {
    s[IDXX(threadIdx.x)] = in[i];
  }
  __syncthreads();
  
  // 結合係数をかける
  if (i < n && j < m) {
    s[IDX(threadIdx.x, threadIdx.y)] = w[j][i] * s[IDXX(threadIdx.x)];
  }
  __syncthreads(); 
  
  // 縮約する
  while (i2 > 0) {
    if (threadIdx.x < i2 && i + i2 < n && j < m) {
      s[IDX(threadIdx.x, threadIdx.y)] += s[IDX(threadIdx.x + i2, threadIdx.y)];
    }
    __syncthreads();
    i2 = i2 / 2;
  }
  
  // 結果を更新する
  if (threadIdx.x == 0 && j < m) {
    atomicAdd(&out[j], s[IDX(0, threadIdx.y)]);
  }
  __syncthreads();
}

// 中間データのデコード
// 最初に out を 0 クリアしておくこと
// 出力値を得るには更に shape を通す必要がある
extern "C"
__global__ void decode(float **w, float *in, int n, int m, float *out) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  int j = threadIdx.y + blockIdx.y * blockDim.y;
  int j2 = blockDim.y / 2;
  extern __shared__ float s[];

  // 入力領域にキャッシュする
  if (threadIdx.x == 0 && j < m) {
    s[IDXX(threadIdx.y)] = in[j];
  }
  __syncthreads();
  
  // 結合係数をかける
  if (i < n && j < m) {
    s[IDX(threadIdx.y, threadIdx.x)] = w[j][i] * s[IDXX(threadIdx.y)];
  }
  __syncthreads();
  
  // 縮約する
  while (j2 > 0) {
    if (threadIdx.y < j2 && j + j2 < m && i < n) {
      s[IDX(threadIdx.y, threadIdx.x)] += s[IDX(threadIdx.y + j2, threadIdx.x)];
    }
    __syncthreads();
    j2 = j2 / 2;
  }

  // 結果を更新する
  if (threadIdx.y == 0 && i < n) {
    atomicAdd(&out[i], s[IDX(0, threadIdx.x)]);
  }
  __syncthreads();
}

// 定数ノードを考慮し、sigmoid 関数で (0,1) にマッピングする
extern "C"
__global__ void shaper(float *b, int n, float *out) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i < n) {
    out[i] = sigmoid(out[i] + b[i]);
  }
}

// 出力に関する定数ノードの更新値を計算する
extern "C"
__global__ void calcB2d(float *x, float *z, int n, float *b2t, float *b2d) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i < n) {
    b2t[i] = x[i] - z[i];
    b2d[i] += b2t[i];
  }
}

// 入力に関する定数ノードの更新値を計算する
extern "C"
__global__ void calcB1d(float **w, float *y, float *b2t, int n, int m, float *b1t, float *b1d) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  int j = threadIdx.y + blockIdx.y * blockDim.y;
  int i2 = blockDim.x / 2;
  extern __shared__ float s[];
  
  // 入力バッファへのロード
  if (threadIdx.y == 0 && i < n) {
    s[IDXX(threadIdx.x)] = b2t[i];
  }
  __syncthreads();
  
  // 結合係数の作用の計算
  if (i < n && j < m) {
    s[IDX(threadIdx.x, threadIdx.y)] = w[j][i] * s[IDXX(threadIdx.x)];
  }
  __syncthreads();
  
  // 縮約
  while (i2 > 0) {
    if (threadIdx.x < i2 && i + i2 < n && j < m) {
      s[IDX(threadIdx.x, threadIdx.y)] += s[IDX(threadIdx.x + i2, threadIdx.y)];
    }
    __syncthreads();
    i2 = i2 / 2;
  }
  
  // 計算値を保存
  if (threadIdx.x == 0 && j < m) {
    float tmp = s[IDX(0, threadIdx.y)] * y[j] * (1.0f - y[j]);
    atomicAdd(&b1t[j], tmp);
    atomicAdd(&b1d[j], tmp);
  }
  __syncthreads();
}

// w の更新値を計算する
extern "C"
__global__ void calcWd(float *b1t, float *b2t, float *x, float *y, int n, int m, float **wd) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  int j = threadIdx.y + blockIdx.y * blockDim.y;
  
  if (i < n && j < m) {
    wd[j][i] += b1t[j] * x[i] + b2t[i] * y[j];
  }
}

// b1, b2 を更新する
// rate は学習率
extern "C"
__global__ void forward(float *bd, float rate, int n, float *b) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i < n) {
    b[i] += rate * bd[i];
  }
}

// w を更新する
// rate は学習率
extern "C"
__global__ void forwardW(float **wd, float rate, int n, int m, float **w) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  int j = threadIdx.y + blockIdx.y * blockDim.y;
  
  if (i < n && j < m) {
    w[j][i] += rate * wd[j][i];
  }
}

高速化のために shared memory を使いまくっています。
逆に言うと、キャッシュしないと GPU を使っても
あまり速くなりません。

最後にアプリケーションを起動するための Java コード。

import java.io.IOException;

public class AutoEncoderApp {
	private static final int N_NODE = 28 * 28;
	private static final int N_INTERMEDIATE = 400;
	private static final int N_PCNT = 6000;
	private static final int N_TCNT = 10;
	private static final int BATCHSIZE = 20;
	private static final float LEARNRATE = 0.01f;
	
	AutoEncoder ae;
	float[][] images;
	
	// 実行メソッド
	public void run() {
		// トレーニングを実施
		ae.training(images, N_PCNT, N_TCNT, BATCHSIZE, LEARNRATE);
		ae.sync();
		
		// トレーニング結果を取り出し
		float[][] results = new float[N_PCNT][];
		for (int i = 0; i < N_PCNT; ++i) {
			ae.setInput(images[i]);
			ae.encode();
			ae.decode();
			ae.sync();
			results[i] = ae.getOutput();
		}
		
		// トレーニング結果を表示
		new Visualizer(28, 28, 10, 10, 1.0f).dispDataImage(results, N_PCNT);
	}
	
	// コンストラクタ
	public AutoEncoderApp(String fname) throws IOException {
		ae = new AutoEncoder(N_NODE, N_INTERMEDIATE);
		images = DATA.readImage(fname, N_PCNT);
	}
	
	// メイン関数
	public static void main(String[] args) throws IOException {
		new AutoEncoderApp("train-images-idx3-ubyte").run();
	}
}

DATA クラスや Visualizer クラスは以前紹介した Java AI プログラミング本
で定義されているクラスです。ファイルからの MNIST データの読み込み、
結果のグラフィック表示を実行します。
Utils は ptx ファイルがなければ nvcc を起動して作成するクラスです。

パターン数 6000 くらいにしても Core i7 6950X では全く進まなかった計算が
GTX 1080 使用で 30 秒くらいで返ってくるようになりました。

JCuda は Hybrid プログラミングとなりますが、慣れてしまえば C++
フルに書くよりも効率的に書けそうです。今回のようなちょっとした実験
には向いていると感じました。