読者です 読者をやめる 読者になる 読者になる

えっ...ちょま...

あと5億年ほしい

【機械学習】手書きの機械学習で逆行列を求めてみた【時々失敗】

MachineLearning

こんにちはみなさん。

一度、出来合いのモデルではなく素手で機械学習をやってみようと思い勃ち、簡単なタスクをやったので覚え書き。

暗号脆弱すぎ

やあみんな。hoge国のスパイである俺piyoはfuga国の暗号解読にいそしんでいる。そしてついに今、fuga国基地の暗号マシーンルームに潜入している!

目の前には巨大な機械がそびえる。こいつにデータを食わせると暗号化して出力するようだ。破壊してしまいたいのは山々だが、新しい暗号マシーンを作られるのが関の山だ。ここはこっそり暗号を解読しておけば、今後のfuga国機密情報が筒抜けだぜ!聞くところによると「行列」とかいうメチャスゴイテクをつかった暗号らしい。えっ...学校で習った...?

 

・・・。

<設定>

ベクトルx(パスワード)が定行列Aを左から掛けることによりAxに暗号化されています。(暗号というか単なる変換です^^;)このとき、元のベクトルxを求めてください。ただし、piyo氏は今暗号マシーンに好きなだけデータを入力することができるので、xAxの組を好きなだけ得ることができます。

 

あえて機械学習で!

連立方程式立てて解けとか言わないで!)

Axを入力、xを教師信号とすれば、うまく収束した場合、任意の x に対して Axx に復号するニューラルネットが得られます。重み行列をWとすれば出力 yy = WAx です!

以下コードです↓

import numpy as np
import matplotlib.pyplot as plt


# 定数たち
dim = 4                                                       # [tex:x]の次元
train_cycle = 500000                                 # 学習回数 50万回まわせば大体おk
A = np.random.randn(dim, dim) * 100       # fuga国の暗号行列(未知)。ここではランダムに生成
eta = 0.0001                                              # 学習率
eps = 0.0001                                             # loss(損失関数)のoverflowを抑える定数
step_plot, loss_plot = [], []                         # lossのプロット用


# 本体

# 重み行列
w = np.zeros((dim, dim))

# ∂(loss) / ∂ w[i][j] : 損失関数をw[i][j]で微分
def diff(w, i, j, x, x_, y):
    return eps * 2 * (y[j] - x[j]) * x_[i]  # overflowするのでeps(≒0)で抑える

# 学習
def train():
    for step in range(train_cycle):
        x = np.random.randn(dim)           # ランダムにデータを生成 (x)
        x_ = x.dot(A)                      # 暗号マシーンにデータを投入すると暗号化された (Ax)
        y = x_.dot(w)                      # x_を入力とし中間層に通す (Wx_)
        for i in range(dim):
            for j in range(dim):
            w[i, j] -= eta * diff(w, i, j, x, x_, y)  # 重み行列の各成分を更新
            if step % 1000 == 0:                      # 1000stepごとにlossの値をプロットデータに追加
                sq_loss = sum([ (x[k] - y[k])**2 for k in range(dim) ]) / (dim*dim)
                step_plot.append(step)
                loss_plot.append(sq_loss)
                if step % 10000 == 0:                 # 10000stepごとにlossの値を表示
                    print('square loss = {}'.format(sq_loss))

# 学習が終わったらテスト
def test():
    x = np.random.randn(dim) # テスト用ランダムデータを生成
    x_ = x.dot(A)
    y = x_.dot(w)
    # 各値を出力
    print('input x :\n{}'.format(x))
    print('output y :\n{}'.format(y))
    print('weight W :\n{}'.format(w))
    print('prediction W^-1 :\n{}'.format(np.linalg.inv(w)))
    print('target A :\n{}'.format(A))
    # Aの行列式を表示 0なら逆行列が存在
    print('det A : {}'.format(np.linalg.det(A)))
    print('A * W =\n{}'.format(A.dot(w)))
    # lossをプロット
    plt.plot(step_plot, loss_plot)
    plt.title('loss')
    plt.xlabel('step')
    plt.ylabel('loss')
    plt.ylim(0, loss_plot[11]) # 11000step程度のloss値をy軸最大値にしておく
    plt.show()[f:id:y-bros:20160925162126p:plain]


train()
test()

実行結果

$ python3 inverseMatrix.py
square loss = 42610.20118829254
square loss = 20323.637924723604
square loss = 30288.688287093377
square loss = 10875.818035001463
square loss = 2811.707447080907
~省略~
square loss = 1.8918150634847766e-12
square loss = 2.9557316483293136e-10
square loss = 7.24908659239759e-10
square loss = 9.553071127287272e-10
square loss = 9.84578463347026e-11
input x :
[-0.21449368  0.1321071  -0.09837691 -0.32187035]
output y :
[-0.21449705  0.13209922 -0.09837548 -0.32187667]
weight W :
[[ 0.00498076  0.00270704 -0.00706897 -0.00828758]
 [ 0.00447457  0.00434026  0.00032405  0.00126585]
 [ 0.00765862 -0.00170078  0.00795789  0.00880727]
 [ 0.01531947 -0.00331515  0.00152671  0.01104624]]
prediction W^-1 :
[[  84.44478897  -18.83191587   74.62494532    6.014635  ]
 [ -50.67046067  224.09681451  -49.48490054  -24.24189844]
 [  64.15737576  -44.0591449   205.70874865 -110.8297228 ]
 [-141.18642187   99.46140533 -146.77590662   90.22966444]]
target A :
[[  84.44299633  -18.8280089    74.62272698    6.01554465]
 [ -50.67012216  224.09607674  -49.48448164  -24.24207021]
 [  64.15669835  -44.05766853  205.70791038 -110.82937907]
 [-141.18829547   99.46548873 -146.77822514   90.23061517]]
det A : 241339578.07193252
A * W =
[[  1.00000550e+00   1.28617876e-05  -2.32635279e-06   1.03129695e-05]
 [ -1.03842653e-06   9.99997571e-01   4.39297419e-07  -1.94745222e-06]
 [  2.07801371e-06   4.86023248e-06   9.99999121e-01   3.89708110e-06]
 [  5.74745646e-06   1.34426325e-05  -2.43141207e-06   1.00001078e+00]]

f:id:y-bros:20160925162126p:plain

おおー!
lossがちゃんと0に収束してます!
outputがほとんどinputを復元しています。
既におわかりかと思いますが、y = WAxですから、WW = A^{-1}に収束します。
ご覧の通り、W^{-1} ≒ AAW ≒ Eです。

ただしこんなときも...

$ python3 inverseMatrix.py
square loss = 291125.0812840772
square loss = 3366.8692509737048
square loss = 2063.704714359924
square loss = 288.6723361615006
square loss = 215.98748429802268
~省略~
square loss = 48.363259443737086
square loss = 1208.6314956931133
square loss = 95.96413476733537
square loss = 730.6723663736443
square loss = 802.7810041874261
input x :
[ 0.80655357 -0.91680023 -0.50208048 -0.22126768]
output y :
[ -5.3675719  -31.59692307 -28.95687437  20.9790567 ]
weight W :
[[-0.32838199 -1.86566602 -1.74597015  1.27611294]
 [-0.85447543 -4.83238558 -4.53401897  3.31281086]
 [ 0.10085938  0.6028011   0.54763324 -0.40815493]
 [-0.55702683 -3.17212337 -2.97067206  2.18198039]]
prediction W^-1 :
[[ 2233.12204066 -1039.85370575   685.30631078   400.93571354]
 [   54.89572853   -48.22366987    83.74420176    56.7756508 ]
 [ -861.00385318   374.37897063  -310.04905227  -122.8501461 ]
 [ -522.33062166   174.13506773  -125.42418616    18.09540679]]
target A :
[[ -88.1011476    25.10448041  -37.61105938   17.39953604]
 [-204.69495009   70.87440723    2.8977706    13.8834304 ]
 [-227.54911123   83.75512351 -112.76710261  -18.1842908 ]
 [-100.07514206  -19.59228771    6.08226371   87.8647652 ]]
det A : -15737682.92898789
A * W =
[[ -6.00576802 -34.81267497 -32.28759859  24.05596633]
 [ -0.78348006  -2.89322576  -3.61083745   2.69026462]
 [  1.91185277   9.50027302   9.81118737  -6.56479998]
 [  1.2744246    6.33280024   5.87346167  -3.37603917]]

f:id:y-bros:20160925162430p:plain

  • はじめの方のstepでlossが急落してしまうと、あとの方の学習が進まない。(どこか局所的な極小値に落ち込んでしまい脱出できない場合か)
  • lossがうまいぐあいに振動状態に陥ってしまうと収束しない。

lossをstep数に応じてうまくコントロールするしくみが必要?

\det A = 0 の場合

Aのどっか1つの行か列を0だけにしてしまえば\det A = 0になります。
このときA逆行列が存在しません。
コードでAを定義した直後に

A[0] = [0] * dim

とすればAの1行目が0になります。

$ python3 inverseMatrix.py
square loss = 772671.1467702463
square loss = 69.4019338825569
square loss = 6.982992110638545
square loss = 1.286551345557311
square loss = 10.424542218341191
~省略~
square loss = 0.05419456847329727
square loss = 1.4819165655932056e-05
square loss = 0.015289836088997822
square loss = 0.006291635570690296
square loss = 0.06200217512746582
input x :
[ 1.77477075 -0.89664841 -0.57233113  0.25513553]
output y :
[ 0.01964648 -0.89664841 -0.57233113  0.25513553]
weight W :
[[-1.07070715 -1.42838061  0.07365913 -1.55981773]
 [-5.67462433 -7.56369352  0.41363814 -8.27081493]
 [-2.27615535 -3.02405876  0.16092615 -3.32045091]
 [-4.21507011 -5.61036664  0.30836188 -6.15556444]]
prediction W^-1 :
[[ 328.69825358 -110.17913259 -258.12609371  203.98728442]
 [-138.16057798    7.42473546  113.43299254  -36.15460113]
 [-207.61205524   53.06209846  -19.665432     -8.0791502 ]
 [-109.55523851   71.33701762   72.38244995 -107.29664931]]
target A :
[[   0.            0.            0.            0.        ]
 [-146.04609099   10.0679469   119.62547043  -41.04828209]
 [-211.20420186   54.2661803   -16.84452631  -10.30840522]
 [-120.01503689   74.84312543   80.59650909 -113.78790944]]
det A : 0.0
A * W =
[[  0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00]
 [ -2.39901275e-02   1.00000000e+00   6.05737682e-13  -9.77706804e-12]
 [ -1.09284019e-02   7.53175300e-13   1.00000000e+00  -5.96855898e-13]
 [ -3.18218861e-02  -1.46656021e-11  -9.59232693e-13   1.00000000e+00]]

f:id:y-bros:20160925165520p:plain
Aの1行目を0にすると、xの1行目だけ復号できませんでした!
yの第1成分が0になってしまうようです。
0にする行番号に応じて復号できない場所がかわるよ!

ちなみに、列を0にすると...

なんかだめな感じになる。

$ python3 inverseMatrix.py
square loss = 773701.2808239068
square loss = 1701.2973314292854
square loss = 16977.430706964115
square loss = 31.125288095912666
square loss = 3831.4422570017064
~省略~
square loss = 0.18113840262486006
square loss = 0.14193590167027684
square loss = 0.024058306539580052
square loss = 0.004436463120030216
square loss = 0.012535315801299605
input x :
[ 0.4432559  -1.82602333  0.7054781   1.27180571]
output y :
[ 0.25960518 -1.82601637  0.6410365   1.33813611]
weight W :
[[ -1.85177696e-03  -7.55816010e-03   4.35443174e-03  -8.59710372e-04]
 [ -1.72270685e+01  -6.32767076e-02   1.81338919e+00  -6.25506408e-01]
 [ -6.53676268e-03  -6.53953602e-03   1.34949024e-02  -5.08275785e-03]
 [ -4.46643774e-03  -6.98725543e-03   1.41969346e-02   1.24244493e-03]]
prediction W^-1 :
[[ -9.17351926e+00  -6.10727329e-02   9.46331112e+00   1.61919724e+00]
 [ -1.83935249e+02   3.14462355e-04   3.63963471e+01   2.17792151e+01]
 [ -9.34438188e+01  -2.10145546e-02   3.51128715e+01   6.84061419e+01]
 [  3.54353209e-01   2.23445938e-02  -1.62516077e+02   1.51516944e+02]][f:id:y-bros:20160925171041p:plain]
target A :
[[  24.3736312     0.          -54.77096033   24.87286821]
 [-184.11179918    0.           36.72722192   21.66174867]
 [ -81.89297097    0.           13.01019509   76.40300608]
 [ -11.9215873     0.         -139.01470271  143.01039946]]
det A : 0.0
A * W =
[[  2.01797123e-01   1.63778028e-04  -2.79876967e-01   2.88336434e-01]
 [  4.10600199e-03   1.00001129e+00   1.45844274e-03  -1.47922214e-03]
 [ -2.74646311e-01   3.22269364e-05   9.03662442e-01   9.92030930e-02]
 [  2.92035196e-01  -5.32681876e-05   1.02407706e-01   8.94509729e-01]]

f:id:y-bros:20160925171041p:plain
loss値が微妙に下がりきらない。(図ではわかりづらいですが)
0のところで暗号化時に情報の欠落があるのだろうか。




ところでこれ、ただの最急降下法な気がするんですが...
機械学習とかホラふいて良いんですかね・・・?