Soyukke.Dev

2020-07-19

Flux.jl - MNISTデータセットで学習

基本的な使い方がわかったので,次はMNISTデータセットを使用して学習してみる.

バージョン情報

  • Flux v0.90

MNISTで学習するために

最初に,機械学習をするために必要なデータ分割やミニバッチをシャッフルしたりするコードを一つ一つ確認していく.次に損失関数や正解率を求めるコードについて.最後に実行結果を示す.

データ準備

FluxではMNISTデータセットをダウンロードできる.画像は28x28の配列のVector.ラベルはInt64Vector.学習データn_traindataは50000,ミニバッチのサイズbatch_sizeは512にしておく.

# 1. 学習データの用意
mnist_images = Flux.Data.MNIST.images() # 学習データ
mnist_labels = Flux.Data.MNIST.labels() # 教師データ 60000-element Array{Int64,1}
n_data = length(mnist_labels)
n_traindata = 50000
batch_size = 512

データのonehot化

MNISTの画像のラベルは 0 ~ 9なので,それをonehotに変換したり元の数字に戻したりする方法.

Flux.onecold

onecoldを使えばモデルから出力されたベクトルの最大値のインデックスに対応する第二引数のラベルに変換できる.予測がなんの数字に対応するか調べるときに使う.第二引数は数字である必要はない.例えば["A", "B", "C"]`なども可能.

 Flux.onecold(rand(10, 5), 0:9)

Flux.onehot

onehotに変換することができる.

Flux.onehot(0, 0:9)

このように出力される.

10-element Flux.OneHotVector:
  true
 false
 false
 false
 false
 false
 false
 false
 false
 false

ミニバッチの配列を作る

ミニバッチ作成にはIterators.partitionが非常に便利で,任意のサイズに配列を区切ることができる.randpermで1~ntraindataまでの数字をシャッフルする.これをbatchsizeのパーティションで区切る.

n_traindata = 10
batch_size = 3
batch_indices = Iterators.partition(randperm(n_traindata), batch_size)

これでbatch_indicesを使ってepochループ内でforすればよい.batch_indices |> collectするとこの例では以下のようになる.batch_size = 3で区切られている.

4-element Array{Array{Int64,1},1}:
 [6, 1, 3] 
 [7, 10, 2]
 [9, 5, 4] 
 [8]

今回はこのように学習データのindicesを分割して元のデータセットの配列をスライスして使っていく.

batch_indices = Iterators.partition(randperm(n_traindata), batch_size)
for (index, batch_indices) in enumerate(batch_indices)
    # ここに各ミニバッチを使った計算
end

損失関数 : crossentropy

Flux.crossentropyを使う.第二引数の教師データはonehotで渡す.(Pytorchではindicesで渡していたが,Fluxでは異なるので注意)

using Flux: onehotbatch, crossentropy

x = rand(10, 5) |> softmax
y = onehotbatch([0, 1, 1, 3, 6], 0:9)
crossentropy(x, y)

xは予測した数字の分類の確率分布を模擬している.softmaxを施しているので縦方向の総和が1になる.横方向はバッチ.

julia> x = rand(10, 5) |> softmax
10×5 Array{Float64,2}:
 0.156741   0.119001   0.0799983  0.0693452  0.0596079
 0.060697   0.126603   0.079851   0.14257    0.131742
 0.0665962  0.128753   0.133022   0.0803353  0.0945145
 0.0961805  0.121828   0.128254   0.0708888  0.0777128
 0.0803562  0.104694   0.0968524  0.120088   0.0712105
 0.118641   0.0511592  0.104646   0.0737453  0.146321
 0.144299   0.0887068  0.0681773  0.138676   0.0747795
 0.0733639  0.0988206  0.116772   0.0833416  0.124023
 0.0945037  0.105863   0.116514   0.137138   0.0789032
 0.108621   0.0545716  0.0759122  0.0838719  0.141185

教師データはonehot.

julia> y = Flux.onehotbatch([0, 1, 1, 3, 6], 0:9)
10×5 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
  true  false  false  false  false
 false   true   true  false  false
 false  false  false  false  false
 false  false  false   true  false
 false  false  false  false  false
 false  false  false  false  false
 false  false  false  false   true
 false  false  false  false  false
 false  false  false  false  false
 false  false  false  false  false

このままcrossentropyに入力すれば計算可能.

julia> crossentropy(x, y)
2.337460868704356

実際には損失関数をこのように定義しておく.xがモデルが予測する確率分布の配列.yが教師データ.

loss_func(x, y) = crossentropy(x, Flux.onehotbatch(y, 0:9))

パラメータの更新

各層のTrackedな配列を更新する関数を定義する.

Optimizer = Union{ADAM, ADAGrad, RMSProp} 

function Flux.Tracker.update!(optimizer::Optimizer, ps::Flux.Tracker.Params)
    for p in ps.order
        update!(optimizer, p, grad(p))
    end
end

ネットワークモデル

適当.途中はDense層を使う.入力は画像をflattenで1次元ベクトル化したもの.出力は分類問題なので0~9でサイズが10のベクトル.最後にsoftmax層を加えることで最後のベクトルの総和が1になるので確率分布のように扱える.

model = Chain(Dense(features, 512, relu),
                        Dense(512, 256, relu),
                        Dense(256, n_categories),
                        softmax)

コード全体

Pytorchっぽく実装.

# train!
using Printf
using Flux
using Flux: back!, crossentropy, onecold, onehotbatch
using Flux.Tracker: grad, update!
using Random
using CuArrays

# 28 x 28の行列をベクトルにする関数
to_vector = img -> collect(Iterators.flatten(Float32.(img)))

f = open("mnist-loss.csv", "w")

# optimizerの型をまとめておく.ADAMなどは抽象型が実装されていない
Optimizer = Union{ADAM, ADAGrad, RMSProp} 

function Flux.Tracker.update!(optimizer::Optimizer, ps::Flux.Tracker.Params)
    for p in ps.order
        update!(optimizer, p, grad(p))
    end
end

mnist_images = Flux.Data.MNIST.images() # 学習データ
mnist_labels = Flux.Data.MNIST.labels() # 教師データ 60000-element Array{Int64,1}
n_data = length(mnist_labels)
n_traindata = 50000
n_testdata = n_data - n_traindata
test_indices = n_traindata+1:n_data
batch_size = 512

features = length(mnist_images[1])
n_categories = 10
model = Chain(Dense(features, 512, relu),
                        Dense(512, 256, relu),
                        Dense(256, n_categories),
                        softmax) |> gpu
optimizer = ADAM()
loss_func(x, y) = crossentropy(x, Flux.onehotbatch(y, 0:9) |> gpu)
accuracy_func(x, y) = sum(onecold(x |> cpu, 0:9) .== y)

@printf(f, "epoch, train_loss, train_acc, test_loss, test_acc\n")
n_epochs = 20
for epoch in 1:n_epochs
    println("epoch: $epoch")
    # train
    loss_total = 0.0
    acc_total = 0.0
    batch_indices = Iterators.partition(randperm(n_traindata), batch_size)
    for (index, batch_indices) in enumerate(batch_indices)
        train_in = reduce(hcat, map(to_vector, mnist_images[batch_indices])) |> gpu
        train_out = mnist_labels[batch_indices]
        predict = model(train_in) # forward
        loss = loss_func(predict, train_out) # 損失
        acc = accuracy_func(predict, train_out) # 正解数
        loss_total += loss.data
        acc_total += acc
        back!(loss) # backward
        update!(optimizer, params(model)) # パラメータ更新
        if index % 5 == 0
            println("index_batch: $((index-1)*batch_size + length(batch_indices)) / $(n_traindata)")
            @printf(stdout, "% .2e\n",  loss.data)
        end
    end
    
    # test
    loss_total_test = 0.0
    acc_total_test = 0.0
    for (index, batch_indices) in enumerate(Iterators.partition(test_indices, 512))
        test_in = reduce(hcat, map(to_vector, mnist_images[batch_indices])) |> gpu
        test_out = mnist_labels[batch_indices]
        predict = model(test_in)
        loss = loss_func(predict, test_out) # 損失
        acc = accuracy_func(predict, test_out) # 正解数
        loss_total_test += loss.data
        acc_total_test += acc
    end
    @printf(f, "%d, % .e, % .e, % .e, % .e\n", epoch, loss_total/n_traindata, acc_total/n_traindata, loss_total_test/n_testdata, acc_total_test/n_testdata)
    flush(f)
end
close(f)

学習結果

各epochにおける学習データとテストデータそれぞれの正解率をプロットした.trainデータの精度はほぼ1になっている.testデータも0.975ぐらいの精度まで学習できている.

using CSV
using Plots
pyplot()
data = CSV.read("mnist-loss.csv")
plot(data[:, 1], data[:, 3], label="train", xlabel="epoch", ylabel="accuracy rate", marker=:circle, legend=:right, markersize=8, dpi=300)
plot!(data[:, 1], data[:, 5], label="test", marker=:circle, legend=:right, markersize=8)
savefig("mnist-acc.png")

カテゴリー