2020-07-19
Flux.jl - MNISTデータセットで学習
基本的な使い方がわかったので,次はMNISTデータセットを使用して学習してみる.
バージョン情報
Flux v0.90
MNISTで学習するために
最初に,機械学習をするために必要なデータ分割やミニバッチをシャッフルしたりするコードを一つ一つ確認していく.次に損失関数や正解率を求めるコードについて.最後に実行結果を示す.
データ準備
FluxではMNISTデータセットをダウンロードできる.画像は28x28の配列のVector
.ラベルはInt64
のVector
.学習データn_traindata
は50000,ミニバッチのサイズbatch_size
は512にしておく.
# 1. 学習データの用意
mnist_images = Flux.Data.MNIST.images() # 学習データ
mnist_labels = Flux.Data.MNIST.labels() # 教師データ 60000-element Array{Int64,1}
n_data = length(mnist_labels)
n_traindata = 50000
batch_size = 512
データのonehot化
MNISTの画像のラベルは 0 ~ 9なので,それをonehotに変換したり元の数字に戻したりする方法.
Flux.onecold
onecold
を使えばモデルから出力されたベクトルの最大値のインデックスに対応する第二引数のラベルに変換できる.予測がなんの数字に対応するか調べるときに使う.第二引数は数字である必要はない.例えば["A", "B", "C"
]`なども可能.
Flux.onecold(rand(10, 5), 0:9)
Flux.onehot
onehotに変換することができる.
Flux.onehot(0, 0:9)
このように出力される.
10-element Flux.OneHotVector:
true
false
false
false
false
false
false
false
false
false
ミニバッチの配列を作る
ミニバッチ作成にはIterators.partition
が非常に便利で,任意のサイズに配列を区切ることができる.randpermで1~ntraindataまでの数字をシャッフルする.これをbatchsizeのパーティションで区切る.
n_traindata = 10
batch_size = 3
batch_indices = Iterators.partition(randperm(n_traindata), batch_size)
これでbatch_indices
を使ってepoch
ループ内でfor
すればよい.batch_indices |> collect
するとこの例では以下のようになる.batch_size = 3
で区切られている.
4-element Array{Array{Int64,1},1}:
[6, 1, 3]
[7, 10, 2]
[9, 5, 4]
[8]
今回はこのように学習データのindices
を分割して元のデータセットの配列をスライスして使っていく.
batch_indices = Iterators.partition(randperm(n_traindata), batch_size)
for (index, batch_indices) in enumerate(batch_indices)
# ここに各ミニバッチを使った計算
end
損失関数 : crossentropy
Flux.crossentropy
を使う.第二引数の教師データはonehotで渡す.(Pytorchではindicesで渡していたが,Fluxでは異なるので注意)
using Flux: onehotbatch, crossentropy
x = rand(10, 5) |> softmax
y = onehotbatch([0, 1, 1, 3, 6], 0:9)
crossentropy(x, y)
xは予測した数字の分類の確率分布を模擬している.softmaxを施しているので縦方向の総和が1になる.横方向はバッチ.
julia> x = rand(10, 5) |> softmax
10×5 Array{Float64,2}:
0.156741 0.119001 0.0799983 0.0693452 0.0596079
0.060697 0.126603 0.079851 0.14257 0.131742
0.0665962 0.128753 0.133022 0.0803353 0.0945145
0.0961805 0.121828 0.128254 0.0708888 0.0777128
0.0803562 0.104694 0.0968524 0.120088 0.0712105
0.118641 0.0511592 0.104646 0.0737453 0.146321
0.144299 0.0887068 0.0681773 0.138676 0.0747795
0.0733639 0.0988206 0.116772 0.0833416 0.124023
0.0945037 0.105863 0.116514 0.137138 0.0789032
0.108621 0.0545716 0.0759122 0.0838719 0.141185
教師データはonehot.
julia> y = Flux.onehotbatch([0, 1, 1, 3, 6], 0:9)
10×5 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
true false false false false
false true true false false
false false false false false
false false false true false
false false false false false
false false false false false
false false false false true
false false false false false
false false false false false
false false false false false
このままcrossentropy
に入力すれば計算可能.
julia> crossentropy(x, y)
2.337460868704356
実際には損失関数をこのように定義しておく.xがモデルが予測する確率分布の配列.yが教師データ.
loss_func(x, y) = crossentropy(x, Flux.onehotbatch(y, 0:9))
パラメータの更新
各層のTracked
な配列を更新する関数を定義する.
Optimizer = Union{ADAM, ADAGrad, RMSProp}
function Flux.Tracker.update!(optimizer::Optimizer, ps::Flux.Tracker.Params)
for p in ps.order
update!(optimizer, p, grad(p))
end
end
ネットワークモデル
適当.途中はDense層を使う.入力は画像をflatten
で1次元ベクトル化したもの.出力は分類問題なので0~9でサイズが10のベクトル.最後にsoftmax層を加えることで最後のベクトルの総和が1になるので確率分布のように扱える.
model = Chain(Dense(features, 512, relu),
Dense(512, 256, relu),
Dense(256, n_categories),
softmax)
コード全体
Pytorchっぽく実装.
# train!
using Printf
using Flux
using Flux: back!, crossentropy, onecold, onehotbatch
using Flux.Tracker: grad, update!
using Random
using CuArrays
# 28 x 28の行列をベクトルにする関数
to_vector = img -> collect(Iterators.flatten(Float32.(img)))
f = open("mnist-loss.csv", "w")
# optimizerの型をまとめておく.ADAMなどは抽象型が実装されていない
Optimizer = Union{ADAM, ADAGrad, RMSProp}
function Flux.Tracker.update!(optimizer::Optimizer, ps::Flux.Tracker.Params)
for p in ps.order
update!(optimizer, p, grad(p))
end
end
mnist_images = Flux.Data.MNIST.images() # 学習データ
mnist_labels = Flux.Data.MNIST.labels() # 教師データ 60000-element Array{Int64,1}
n_data = length(mnist_labels)
n_traindata = 50000
n_testdata = n_data - n_traindata
test_indices = n_traindata+1:n_data
batch_size = 512
features = length(mnist_images[1])
n_categories = 10
model = Chain(Dense(features, 512, relu),
Dense(512, 256, relu),
Dense(256, n_categories),
softmax) |> gpu
optimizer = ADAM()
loss_func(x, y) = crossentropy(x, Flux.onehotbatch(y, 0:9) |> gpu)
accuracy_func(x, y) = sum(onecold(x |> cpu, 0:9) .== y)
@printf(f, "epoch, train_loss, train_acc, test_loss, test_acc\n")
n_epochs = 20
for epoch in 1:n_epochs
println("epoch: $epoch")
# train
loss_total = 0.0
acc_total = 0.0
batch_indices = Iterators.partition(randperm(n_traindata), batch_size)
for (index, batch_indices) in enumerate(batch_indices)
train_in = reduce(hcat, map(to_vector, mnist_images[batch_indices])) |> gpu
train_out = mnist_labels[batch_indices]
predict = model(train_in) # forward
loss = loss_func(predict, train_out) # 損失
acc = accuracy_func(predict, train_out) # 正解数
loss_total += loss.data
acc_total += acc
back!(loss) # backward
update!(optimizer, params(model)) # パラメータ更新
if index % 5 == 0
println("index_batch: $((index-1)*batch_size + length(batch_indices)) / $(n_traindata)")
@printf(stdout, "% .2e\n", loss.data)
end
end
# test
loss_total_test = 0.0
acc_total_test = 0.0
for (index, batch_indices) in enumerate(Iterators.partition(test_indices, 512))
test_in = reduce(hcat, map(to_vector, mnist_images[batch_indices])) |> gpu
test_out = mnist_labels[batch_indices]
predict = model(test_in)
loss = loss_func(predict, test_out) # 損失
acc = accuracy_func(predict, test_out) # 正解数
loss_total_test += loss.data
acc_total_test += acc
end
@printf(f, "%d, % .e, % .e, % .e, % .e\n", epoch, loss_total/n_traindata, acc_total/n_traindata, loss_total_test/n_testdata, acc_total_test/n_testdata)
flush(f)
end
close(f)
学習結果
各epochにおける学習データとテストデータそれぞれの正解率をプロットした.trainデータの精度はほぼ1になっている.testデータも0.975ぐらいの精度まで学習できている.
using CSV
using Plots
pyplot()
data = CSV.read("mnist-loss.csv")
plot(data[:, 1], data[:, 3], label="train", xlabel="epoch", ylabel="accuracy rate", marker=:circle, legend=:right, markersize=8, dpi=300)
plot!(data[:, 1], data[:, 5], label="test", marker=:circle, legend=:right, markersize=8)
savefig("mnist-acc.png")
カテゴリー