今更、JCuda で AutoEncoder を実装してみた
JCuda の使い勝手を研究するため、AutoEncoder のさわりの部分を実装してみました。
まだバグはいそうですが...
CUDA Driver インタフェースを使うため、初期化と kernel 呼び出しが多少面倒ですが、
メソッドに押し込めてしまうことにします。
まず、AutoEncoder.java のコード。
import jcuda.Sizeof; import jcuda.driver.CUcontext; import jcuda.driver.CUdevice; import jcuda.driver.CUdeviceptr; import jcuda.driver.CUfunction; import jcuda.driver.CUmodule; import jcuda.driver.JCudaDriver; import static jcuda.driver.JCudaDriver.*; import java.io.IOException; import jcuda.Pointer; // AutoEncoder クラス public class AutoEncoder { private static final int NTHREAD = 16; // GPU スレッド数; 実際にはこの2乗ある private static final int SHARED_SIZE = NTHREAD * (NTHREAD + 1) * Sizeof.FLOAT; // shared memory サイズ private static final String PTXFILENAME = "AutoEncoder.cu"; int m; // 中間ノード数 int n; // 入力/出力ノード数 CUdeviceptr devX, devY, devZ; // 入力値、中間値、出力値 CUdeviceptr devB1, devB2; // 定数ノード CUdeviceptr devW; // パーセプトロンの結合係数 CUdeviceptr devWd, devB1d, devB1t, devB2d, devB2t; // 逆誤差伝播用 String ptxFileName; // CUDA コードファイル名 CUfunction fClear, fClearW; CUfunction fEncode, fDecode, fShaper, fCalcB2d, fCalcB1d, fCalcWd, fForward, fForwardW; // エントリポイント // CUDA カーネルの呼び出し protected static void callKernel(CUfunction func, Pointer param, int num, int sharedSize) { int block = (num + NTHREAD - 1) / NTHREAD; cuLaunchKernel(func, block, 1, 1, // block NTHREAD, 1, 1, // thread sharedSize, null, // shared size, stream param, null // kernel parameters ); cuCtxSynchronize(); } protected static void callKernel(CUfunction func, Pointer param, int num) { callKernel(func, param, num, 0); } // CUDA カーネル呼び出し(二次元版) protected static void callKernel2(CUfunction func, Pointer param, int numX, int numY) { int blockX = (numX + NTHREAD - 1) / NTHREAD; int blockY = (numY + NTHREAD - 1) / NTHREAD; cuLaunchKernel(func, blockX, blockY, 1, // block NTHREAD, NTHREAD, 1, // thread SHARED_SIZE, null, // shared size, stream param, null // kernel parameters ); cuCtxSynchronize(); } // コンストラクタ public AutoEncoder(int n, int m) throws IOException { this.m = m; this.n = n; JCudaDriver.setExceptionsEnabled(true); ptxFileName = Utils.preparePtxFile(PTXFILENAME); // CUDA 初期化 cuInit(0); CUcontext pctx = new CUcontext(); CUdevice dev = new CUdevice(); cuDeviceGet(dev, 0); cuCtxCreate(pctx, 0, dev); // PTX ファイルをロードする CUmodule module = new CUmodule(); cuModuleLoad(module, ptxFileName); // 関数のエントリポイントを取得する fClear = new CUfunction(); cuModuleGetFunction(fClear, module, "clear"); fClearW = new CUfunction(); cuModuleGetFunction(fClearW, module, "clearW"); fEncode = new CUfunction(); cuModuleGetFunction(fEncode, module, "encode"); fDecode = new CUfunction(); cuModuleGetFunction(fDecode, module, "decode"); fShaper = new CUfunction(); cuModuleGetFunction(fShaper, module, "shaper"); fCalcB2d = new CUfunction(); cuModuleGetFunction(fCalcB2d, module, "calcB2d"); fCalcB1d = new CUfunction(); cuModuleGetFunction(fCalcB1d, module, "calcB1d"); fCalcWd = new CUfunction(); cuModuleGetFunction(fCalcWd, module, "calcWd"); fForward = new CUfunction(); cuModuleGetFunction(fForward, module, "forward"); fForwardW = new CUfunction(); cuModuleGetFunction(fForwardW, module, "forwardW"); devX = new CUdeviceptr(); cuMemAlloc(devX, n * Sizeof.FLOAT); devY = new CUdeviceptr(); cuMemAlloc(devY, m * Sizeof.FLOAT); devZ = new CUdeviceptr(); cuMemAlloc(devZ, n * Sizeof.FLOAT); devB1 = new CUdeviceptr(); cuMemAlloc(devB1, m * Sizeof.FLOAT); devB2 = new CUdeviceptr(); cuMemAlloc(devB2, n * Sizeof.FLOAT); devB1d = new CUdeviceptr(); cuMemAlloc(devB1d, m * Sizeof.FLOAT); devB1t = new CUdeviceptr(); cuMemAlloc(devB1t, m * Sizeof.FLOAT); devB2d = new CUdeviceptr(); cuMemAlloc(devB2d, n * Sizeof.FLOAT); devB2t = new CUdeviceptr(); cuMemAlloc(devB2t, n * Sizeof.FLOAT); devW = new CUdeviceptr(); cuMemAlloc(devW, m * Sizeof.POINTER); devWd = new CUdeviceptr(); cuMemAlloc(devWd, m * Sizeof.POINTER); CUdeviceptr[] wtmp = new CUdeviceptr[m]; for (int i = 0; i < m; ++i) { wtmp[i] = new CUdeviceptr(); cuMemAlloc(wtmp[i], n * Sizeof.FLOAT); // 係数 w をランダムに初期化する float[] winit = new float[n]; for (int j = 0; j < n; ++j) { winit[j] = (float)((Math.random() - 0.5) * 2.0 * 0.01); } cuMemcpyHtoD(wtmp[i], Pointer.to(winit), n * Sizeof.FLOAT); } cuMemcpyHtoD(devW, Pointer.to(wtmp), m * Sizeof.POINTER); CUdeviceptr[] wdtmp = new CUdeviceptr[m]; for (int i = 0; i < m; ++i) { wdtmp[i] = new CUdeviceptr(); cuMemAlloc(wdtmp[i], n * Sizeof.FLOAT); } cuMemcpyHtoD(devWd, Pointer.to(wdtmp), m * Sizeof.POINTER); initBias(); } @Override public void finalize() { cuMemFree(devX); cuMemFree(devY); cuMemFree(devZ); cuMemFree(devB1); cuMemFree(devB2); cuMemFree(devB1t); cuMemFree(devB1d); cuMemFree(devB2t); cuMemFree(devB2d); CUdeviceptr[] ptr = new CUdeviceptr[m]; cuMemcpyDtoH(Pointer.to(ptr), devW, m * Sizeof.POINTER); for (int i = 0; i < m; ++i) { cuMemFree(ptr[i]); } cuMemFree(devW); cuMemcpyDtoH(Pointer.to(ptr), devWd, m * Sizeof.POINTER); for (int i = 0; i < m; ++i) { cuMemFree(ptr[i]); } cuMemFree(devWd); } // CUDA スレッドの終了を待つ public void sync() { cuCtxSynchronize(); } // devX に入力データをセットする public void setInput(float[] in) { assert(in.length == n); cuMemcpyHtoD(devX, Pointer.to(in), n * Sizeof.FLOAT); } // devZ から出力値を取得する public float[] getOutput() { float[] out = new float[n]; cuMemcpyDtoH(Pointer.to(out), devZ, n * Sizeof.FLOAT); return out; } // devY から中間値を取得する public float[] getIntermediate() { float[] im = new float[m]; cuMemcpyDtoH(Pointer.to(im), devY, m * Sizeof.FLOAT); return im; } // 入力データを中間値にエンコードする public void encode() { // 中間値を 0 クリアする Pointer p = Pointer.to(Pointer.to(devY), Pointer.to(new int[]{m})); callKernel(fClear, p, m); // encode を実行する p = Pointer.to(Pointer.to(devW), Pointer.to(devX), Pointer.to(new int[]{n}), Pointer.to(new int[]{m}), Pointer.to(devY)); callKernel2(fEncode, p, n, m); // 定数ノードの影響および sigmoid 関数を通す p = Pointer.to(Pointer.to(devB1), Pointer.to(new int[]{m}), Pointer.to(devY)); callKernel(fShaper, p, m); } // 中間値を出力値にデコードする public void decode() { // 出力値を 0 クリアする Pointer p = Pointer.to(Pointer.to(devZ), Pointer.to(new int[]{n})); callKernel(fClear, p, n); // decode を実行する p = Pointer.to(Pointer.to(devW), Pointer.to(devY), Pointer.to(new int[]{n}), Pointer.to(new int[]{m}), Pointer.to(devZ)); callKernel2(fDecode, p, n, m); // 定数ノードの影響および sigmoid 関数を通す p = Pointer.to(Pointer.to(devB2), Pointer.to(new int[]{n}), Pointer.to(devZ)); callKernel(fShaper, p, n); } // 誤差領域を初期化する public void initDiff() { Pointer p = Pointer.to(Pointer.to(devB1d), Pointer.to(new int[]{m})); callKernel(fClear, p, m); p = Pointer.to(Pointer.to(devB2d), Pointer.to(new int[]{n})); callKernel(fClear, p, n); p = Pointer.to(Pointer.to(devWd), Pointer.to(new int[]{n}), Pointer.to(new int[]{m})); callKernel2(fClearW, p, n, m); } // 逆誤差散乱 public void updateDiff() { // B1t の 0 クリア Pointer p = Pointer.to(Pointer.to(devB1t), Pointer.to(new int[]{m})); callKernel(fClear, p, m); // B2t, B2d のアップデート p = Pointer.to(Pointer.to(devX), Pointer.to(devZ), Pointer.to(new int[]{n}), Pointer.to(devB2t), Pointer.to(devB2d)); callKernel(fCalcB2d, p, n); // B1t, B1d のアップデート p = Pointer.to(Pointer.to(devW), Pointer.to(devY), Pointer.to(devB2t), Pointer.to(new int[]{n}), Pointer.to(new int[]{m}), Pointer.to(devB1t), Pointer.to(devB1d)); callKernel2(fCalcB1d, p, n, m); // Wd のアップデート p = Pointer.to(Pointer.to(devB1t), Pointer.to(devB2t), Pointer.to(devX), Pointer.to(devY), Pointer.to(new int[]{n}), Pointer.to(new int[]{m}), Pointer.to(devWd)); callKernel2(fCalcWd, p, n, m); } // 定数バイアスの初期化 public void initBias() { Pointer p = Pointer.to(Pointer.to(devB1), Pointer.to(new int[]{m})); callKernel(fClear, p, m); p = Pointer.to(Pointer.to(devB2), Pointer.to(new int[]{n})); callKernel(fClear, p, n); } // 定数バイアス、結合係数の更新 public void updateBias(float rate) { // b2 のアップデート Pointer p = Pointer.to(Pointer.to(devB2d), Pointer.to(new float[]{rate}), Pointer.to(new int[]{n}), Pointer.to(devB2)); callKernel(fForward, p, n); // b1 のアップデート p = Pointer.to(Pointer.to(devB1d), Pointer.to(new float[]{rate}), Pointer.to(new int[]{m}), Pointer.to(devB1)); callKernel(fForward, p, m); // w のアップデート p = Pointer.to(Pointer.to(devWd), Pointer.to(new float[]{rate}), Pointer.to(new int[]{n}), Pointer.to(new int[]{m}), Pointer.to(devW)); callKernel2(fForwardW, p, n, m); } // トレーニング // pCnt: パターン数 // tCnt: トレーニング回数 // bSize: トレーニングバッチサイズ // rate: 学習率 public void training(float[][] data, int pCnt, int tCnt, int bSize, float rate) { int count = 0; for (int i = 0; i < pCnt * tCnt / bSize; ++i) { // 誤差を初期化する initDiff(); // バッチループ for (int j = 0; j < bSize; ++j) { int index = (j + i * bSize) % pCnt; setInput(data[index]); encode(); decode(); updateDiff(); if (count % 100 == 100 - 1) { System.out.print("."); if (count % 10000 == 10000 - 1) { System.out.println(count + 1); } } ++count; } // 結合係数と定数ノードを更新する updateBias(rate); } System.out.println("trained"); } }
気をつけるべき点は、二重配列でしょうか。
float[][] は float ** に変換されますが、デバイスメモリでは
float * の配列として定義する必要があります。
続いて CUDA コード。
// 以下、n: 入力値/出力値の個数, m: 中間ノードの個数とする。 const int NTHREAD = 16; #define IDX(x, y) (x + y * NTHREAD) #define IDXX(x) (x + NTHREAD * NTHREAD) // sigmoid 関数 __device__ float sigmoid(float x) { return 1.0f / (1.0f + exp(-x)); } // バッファを 0 クリアする extern "C" __global__ void clear(float *buf, int n) { int i = threadIdx.x + blockIdx.x * blockDim.x; if (i < n) { buf[i] = 0.0f; } } // W バッファを 0 クリアする extern "C" __global__ void clearW(float **buf, int n, int m) { int i = threadIdx.x + blockIdx.x * blockDim.x; int j = threadIdx.y + blockIdx.y * blockDim.y; if (i < n && j < m) { buf[j][i] = 0.0f; } } // 入力データのエンコード // 最初に out を 0 クリアしておくこと // 出力値を得るには更に shape を通す必要がある extern "C" __global__ void encode(float **w, float *in, int n, int m, float *out) { int i = threadIdx.x + blockIdx.x * blockDim.x; int j = threadIdx.y + blockIdx.y * blockDim.y; int i2 = blockDim.x / 2; extern __shared__ float s[]; // 入力領域にキャッシュする if (threadIdx.y == 0 && i < n) { s[IDXX(threadIdx.x)] = in[i]; } __syncthreads(); // 結合係数をかける if (i < n && j < m) { s[IDX(threadIdx.x, threadIdx.y)] = w[j][i] * s[IDXX(threadIdx.x)]; } __syncthreads(); // 縮約する while (i2 > 0) { if (threadIdx.x < i2 && i + i2 < n && j < m) { s[IDX(threadIdx.x, threadIdx.y)] += s[IDX(threadIdx.x + i2, threadIdx.y)]; } __syncthreads(); i2 = i2 / 2; } // 結果を更新する if (threadIdx.x == 0 && j < m) { atomicAdd(&out[j], s[IDX(0, threadIdx.y)]); } __syncthreads(); } // 中間データのデコード // 最初に out を 0 クリアしておくこと // 出力値を得るには更に shape を通す必要がある extern "C" __global__ void decode(float **w, float *in, int n, int m, float *out) { int i = threadIdx.x + blockIdx.x * blockDim.x; int j = threadIdx.y + blockIdx.y * blockDim.y; int j2 = blockDim.y / 2; extern __shared__ float s[]; // 入力領域にキャッシュする if (threadIdx.x == 0 && j < m) { s[IDXX(threadIdx.y)] = in[j]; } __syncthreads(); // 結合係数をかける if (i < n && j < m) { s[IDX(threadIdx.y, threadIdx.x)] = w[j][i] * s[IDXX(threadIdx.y)]; } __syncthreads(); // 縮約する while (j2 > 0) { if (threadIdx.y < j2 && j + j2 < m && i < n) { s[IDX(threadIdx.y, threadIdx.x)] += s[IDX(threadIdx.y + j2, threadIdx.x)]; } __syncthreads(); j2 = j2 / 2; } // 結果を更新する if (threadIdx.y == 0 && i < n) { atomicAdd(&out[i], s[IDX(0, threadIdx.x)]); } __syncthreads(); } // 定数ノードを考慮し、sigmoid 関数で (0,1) にマッピングする extern "C" __global__ void shaper(float *b, int n, float *out) { int i = threadIdx.x + blockIdx.x * blockDim.x; if (i < n) { out[i] = sigmoid(out[i] + b[i]); } } // 出力に関する定数ノードの更新値を計算する extern "C" __global__ void calcB2d(float *x, float *z, int n, float *b2t, float *b2d) { int i = threadIdx.x + blockIdx.x * blockDim.x; if (i < n) { b2t[i] = x[i] - z[i]; b2d[i] += b2t[i]; } } // 入力に関する定数ノードの更新値を計算する extern "C" __global__ void calcB1d(float **w, float *y, float *b2t, int n, int m, float *b1t, float *b1d) { int i = threadIdx.x + blockIdx.x * blockDim.x; int j = threadIdx.y + blockIdx.y * blockDim.y; int i2 = blockDim.x / 2; extern __shared__ float s[]; // 入力バッファへのロード if (threadIdx.y == 0 && i < n) { s[IDXX(threadIdx.x)] = b2t[i]; } __syncthreads(); // 結合係数の作用の計算 if (i < n && j < m) { s[IDX(threadIdx.x, threadIdx.y)] = w[j][i] * s[IDXX(threadIdx.x)]; } __syncthreads(); // 縮約 while (i2 > 0) { if (threadIdx.x < i2 && i + i2 < n && j < m) { s[IDX(threadIdx.x, threadIdx.y)] += s[IDX(threadIdx.x + i2, threadIdx.y)]; } __syncthreads(); i2 = i2 / 2; } // 計算値を保存 if (threadIdx.x == 0 && j < m) { float tmp = s[IDX(0, threadIdx.y)] * y[j] * (1.0f - y[j]); atomicAdd(&b1t[j], tmp); atomicAdd(&b1d[j], tmp); } __syncthreads(); } // w の更新値を計算する extern "C" __global__ void calcWd(float *b1t, float *b2t, float *x, float *y, int n, int m, float **wd) { int i = threadIdx.x + blockIdx.x * blockDim.x; int j = threadIdx.y + blockIdx.y * blockDim.y; if (i < n && j < m) { wd[j][i] += b1t[j] * x[i] + b2t[i] * y[j]; } } // b1, b2 を更新する // rate は学習率 extern "C" __global__ void forward(float *bd, float rate, int n, float *b) { int i = threadIdx.x + blockIdx.x * blockDim.x; if (i < n) { b[i] += rate * bd[i]; } } // w を更新する // rate は学習率 extern "C" __global__ void forwardW(float **wd, float rate, int n, int m, float **w) { int i = threadIdx.x + blockIdx.x * blockDim.x; int j = threadIdx.y + blockIdx.y * blockDim.y; if (i < n && j < m) { w[j][i] += rate * wd[j][i]; } }
高速化のために shared memory を使いまくっています。
逆に言うと、キャッシュしないと GPU を使っても
あまり速くなりません。
最後にアプリケーションを起動するための Java コード。
import java.io.IOException; public class AutoEncoderApp { private static final int N_NODE = 28 * 28; private static final int N_INTERMEDIATE = 400; private static final int N_PCNT = 6000; private static final int N_TCNT = 10; private static final int BATCHSIZE = 20; private static final float LEARNRATE = 0.01f; AutoEncoder ae; float[][] images; // 実行メソッド public void run() { // トレーニングを実施 ae.training(images, N_PCNT, N_TCNT, BATCHSIZE, LEARNRATE); ae.sync(); // トレーニング結果を取り出し float[][] results = new float[N_PCNT][]; for (int i = 0; i < N_PCNT; ++i) { ae.setInput(images[i]); ae.encode(); ae.decode(); ae.sync(); results[i] = ae.getOutput(); } // トレーニング結果を表示 new Visualizer(28, 28, 10, 10, 1.0f).dispDataImage(results, N_PCNT); } // コンストラクタ public AutoEncoderApp(String fname) throws IOException { ae = new AutoEncoder(N_NODE, N_INTERMEDIATE); images = DATA.readImage(fname, N_PCNT); } // メイン関数 public static void main(String[] args) throws IOException { new AutoEncoderApp("train-images-idx3-ubyte").run(); } }
DATA クラスや Visualizer クラスは以前紹介した Java AI プログラミング本
で定義されているクラスです。ファイルからの MNIST データの読み込み、
結果のグラフィック表示を実行します。
Utils は ptx ファイルがなければ nvcc を起動して作成するクラスです。
パターン数 6000 くらいにしても Core i7 6950X では全く進まなかった計算が
GTX 1080 使用で 30 秒くらいで返ってくるようになりました。
JCuda は Hybrid プログラミングとなりますが、慣れてしまえば C++ で
フルに書くよりも効率的に書けそうです。今回のようなちょっとした実験
には向いていると感じました。