2015年9月9日水曜日

Rubyでパーセプトロンを実装(改)

前回(→Rubyでパーセプトロンを実装)書いたやつは、特徴空間が1次元でクラス数が2であるものにしか対応していませんでしたが、それを改良して1次元以上2クラス以上のものにも対応できるようにしました。


ソースコード

前回はそのまま配列をベクトルとして扱ってましたが、今回はVectorを使ってみました。こっちの方が断然見やすくていい感じですね。

初期化のふたつ目の引数はパターンの次元数です。この数字と与えた学習パターンの次元数に齟齬があると、実行できません。あと、線形分離可能な学習パターンの集合を入力しないと実行が終わらないので、注意してください。

perceptron.rb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
require 'matrix'
 
class Perceptron
  RHO = 1
  attr_accessor :weight_vectors
 
  # 重みベクトルの初期化とか
  def initialize(learning_patterns, dimension)
    @learning_patterns = learning_patterns.map do |patterns|
      patterns.map { |pattern| Vector[1, *pattern] }
    end
 
    @class_num = @learning_patterns.size
    @weight_vectors = Array.new(@class_num) do |i|
      Vector[ *Array.new(dimension + 1) { i } ]
    end
  end
 
  # クラスiの識別関数 g_i(x)
  def discriminate(pattern, i)
    @weight_vectors[i].inner_product(pattern)
  end
 
  # 誤識別があった場合、重みベクトルを修正する
  def correct_errors
    @class_num.times do |i|
      @learning_patterns[i].each do |pattern|
 
        # jは識別結果
        j = @class_num.times.map do |c|
          { :class => c, :val => discriminate(pattern, c) }
        end.max_by { |e| e[:val] }[:class]
 
        # それが学習パターンのクラスと異なるなら修正
        if i != j
          @weight_vectors[i] += RHO * pattern
          @weight_vectors[j] -= RHO * pattern
        end
      end
    end
  end
 
  # 重みベクトルが更新されなくなるまで誤り訂正を繰り返す
  def learn
    correct_errors
    prev = nil
    while prev != @weight_vectors
      prev = Marshal.load(Marshal.dump(@weight_vectors))
      correct_errors
    end
  end
end


アルゴリズム

前回のものと流れは同じなので、詳しくはそっちを参照してほしいです。

複数クラスを想定しているので、前回と異なる点は誤識別したときの重みベクトルの修正の処理です。クラスiに属するパターンをクラスjであると誤識別した場合、
  • Wi' = Wi + ρ・X
  • Wj' = Wj - ρ・X
のようにクラスiとjの重みベクトルが修正されます。


例1 (特徴空間1次元 クラス数3)

今回、実行ファイルの例は別のところへ置いておきました。gnuplotのgemがインストールされていれば動きます。

perceptron.rb 自体は一応、特徴空間1次元以上クラス数2以上のものに対応できるようになったのですが、それによって求められた重みベクトルが正しいかどうかを確認するための図を出力する処理までは汎用性のあるものを書くのが難しそうだったので、とりあえずこんな感じでプロットするとよさそうみたいな例をそれぞれ書いてみました。

https://gist.github.com/seinosuke/eef6495c1b463b156d87#file-example01-rb
このリンク先の example01.rb を実行すると…
(出力されるのは図2のみです)

 

図1 各学習パターン
 
図2 求められた境界


例2 (特徴空間2次元 クラス数2)

先程のリンク先の example02.rb を実行すると…
(出力されるのは図4のみです)

 

図3 各学習パターン

図4 求められた境界


おわりに

どんな次元やクラス数にも対応したとは言っても、1クラスにつき1プロトタイプではやはり限界があり、次元やクラス数が大きくなってくると今回書いたものでも境界がうまく求められません。

もっと正確に識別できるように、次は線形分離不可能なやつに対応したり、誤差評価のやつとかを書いてみたいです。

参考文献
石井健一郎・前田英作・上田修功・村瀬洋 (1998) 『わかりやすいパターン認識』 オーム社

0 件のコメント:

コメントを投稿