PyTorchについて概要を把握したいと思ってましたが。。。
今後必要になりそうなので、やっと調べ始めました。
公式ドキュメントを読んでも、まぁ頭に入らないので、pytorchを使用しているプログラムから学んで行くことにしよう。
と思っていましたが、結論いうと、
中途半端な知識ではほとんどわからない、ということがわかりました。。
これ以降はだたの駄文、汚物じゃ!!
今年かけてpytorchならびにDeeplearningについて学習し、
年末にこの汚物を修正し、神へと昇華させる所存でございます。
お勉強用プログラム
Deep learning で一番ホットな企業であるNVIDIAのサンプルをお勉強用としてます。
そのNVIDIAのデバイスを使用した、JetBotというものがあり、それ用に提供されているCollision Avoidanceのデモプログラムを見ています。
コードは、↓Githubから取ってきてます。
- https://github.com/NVIDIA-AI-IOT/jetbot
- notebooks/collision_avoidance/live_demo.ipynb
さっそくレッツ&ゴー!!
[1] import torch [2] import torchvision
[1][2] importしている"torch" と "torchvision"は何者ぞ?
公式ドキュメントによると、
- torch: Tensorデータやその算術とかのパッケージ
- torchvision: Computer Visionのためのパッケージ。具体的には下記のようなもの
- popular datasets
- model architectures
- common image transformations
torchvisionの"vision"は、computer visionのvisionか、なるほど。
"torch"単品のほうは、computer visionに限らず、汎用的なものが詰まっているのか、なるほど。
[3] model = torchvision.models.alexnet(pretrained=False) [4] model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)
[3] torchvisionパッケージで、かの有名なalexnetを召喚している。(torchvisionドキュメント箇所)
[4] よくわからん。。一つずつみていこうか。
- model.classifier[6] :
- マジックナンバー6ってなんだ?
- ?今のところよくわかっていないけど、最終レイヤーを弄っているっぽい。
- torch.nn.Linear:
- torch.nn:
- グラフのためのベーシックなブロックとのこと。
- ところで、torch.nnの”nn”って何の略なんだ?よくわからん。
- torch.nn.Linear
- 入力データを線形変換(Linear transformation)するもの。全然わかんね。。
- torch.nn:
[5] model.load_state_dict(torch.load('best_model.pth'))
[5] モデルの重みデータを読み込んでいる
[6] device = torch.device('cuda')
[7] model = model.to(device)
[6] torch.device: cuda使います。と宣言。
[7] 調べてないけど、modelをdevice(cuda,,というかGPU?)へロードします宣言。
[8]
import torch.nn.functional as F
import time
def update(change):
global robot
x = change['new'] # <--- camera input
x = preprocess(x) # <--- cmera input preprocessed
y = model(x) # <--- Alextnetにかけた結果がy
y = F.softmax(y, dim=1) # <--- yをsoftmax関数にかけて、0〜100%の値に変換する。
prob_blocked = float(y.flatten()[0]) #<---flatten後の最初の要素が、blocked状態かどうかを示すものらしい。
if prob_blocked < 0.5: <--- 50%より小さいと直進
robot.forward(speed_slider.value)
else: <--- 50%以上だと、Blockedとみなし、左へ曲がる。
robot.left(speed_slider.value)
time.sleep(0.001)
update({'new': camera.value}) # we call the function once to initialize"
[8] プログラム中にコメントとして記載。