2015年12月20日日曜日

Rubyで学習ベクトル量子化(改)

前回(http://syoshinsyakangeisagi.blogspot.com/2015/11/ruby.html)、
Rubyで学習ベクトル量子化をやってみたのですが今回はその改良版です。


ソースコード

lvq.rb だけ以下のように変更されました。

lvq.rb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class LVQ
  attr_accessor :log
  ALPHA = 0.005
  REP_MAX = 6
 
  def initialize(learning_patterns, dimension, threshold = 0.01)
    @log = []
    @dimension = dimension
    @threshold = threshold
    @class_num = learning_patterns.size
    @learning_patterns = learning_patterns.map do |patterns|
      patterns.map { |pattern| Vector[*pattern.map(&:to_f)] }
    end
    @class_size = @learning_patterns.map { |patterns| patterns.size }
    @error_prob = Array.new(@class_num) { 0.0 }
 
    # 各クラスのどれかを代表パターンの初期値に
    @representative_patterns = @class_num.times.map do |i|
      [ Marshal.load(Marshal.dump(@learning_patterns[i].sample)) ]
    end
  end
 
  def learn
    loop do
      5.times { correct_errors }
      cal_error
      break if @error_cnt.to_f / @class_size.inject(:+) < @threshold
 
      # 最も誤識別が多いクラスの代表ベクトルを追加する
      i = @error_prob.index(@error_prob.max)
      unless @representative_patterns[i].size >= REP_MAX
        @representative_patterns[i].tap do |r|
          r << Marshal.load(Marshal.dump(r.sample))
        end
      end
    end
  rescue Interrupt
    return
  end
 
  # 代表パターンを修正していく
  def correct_errors
    @learning_patterns.each_with_index do |patterns, i|
      patterns.each do |pattern|
        r_i, r_j = nearest_neighbor(pattern)
        if i == r_i
          @representative_patterns[r_i][r_j] +=
            ALPHA * (pattern - @representative_patterns[r_i][r_j])
        else
          @representative_patterns[r_i][r_j] -=
            ALPHA * (pattern - @representative_patterns[r_i][r_j])
        end
      end
    end
    @log << Marshal.load(Marshal.dump(@representative_patterns))
  end
 
  # 各クラスの誤識別率を計算する
  def cal_error
    @error_prob = Array.new(@class_num) { 0.0 }
    @error_cnt = 0
    @learning_patterns.each_with_index do |patterns, i|
      error_cnt = 0
      patterns.each do |pattern|
        r_i, r_j = nearest_neighbor(pattern)
        if i != r_i
          error_cnt += 1; @error_cnt += 1
        end
      end
      @error_prob[i] = error_cnt.to_f / @class_size[i].to_f
    end
  end
 
  # 最近傍の代表パターンは何クラスの何番目のものかを返す
  def nearest_neighbor(l_pattern)
    @representative_patterns.map.with_index do |patterns, i|
      patterns.map.with_index do |r_pattern, j|
        distance = @dimension.times.inject(0) do |sum, k|
          sum + (r_pattern[k] - l_pattern[k])**2
        end
        { :at => [i, j], :distance => distance }
      end
    end.flatten.min_by { |h| h[:distance] }[:at]
  end
end

変更点

各クラスの代表ベクトルは最初ひとつだけとし、誤識別が多かったクラスの代表ベクトルの数を増やしていくようにしました。

各代表ベクトルが更新されなくなってきたら、最も誤識別が多いクラスの代表ベクトルを追加するという処理を30行あたり周辺でやっています。unlessのところでは代表ベクトルが6個以上にならないようにしています(代表ベクトルがうじゃうじゃ増えるのを防ぐ)。

終了条件を設定するのが難しく、厳しくしすぎると終わらなかったりするので注意が必要です。


実行結果

 図1のような3クラス(色別になってる)の学習パターンに対して処理をします。

図1 学習パターン

そして lvq.rb だけ差し替えて前回の main.rb を実行。すると多分、代表ベクトルが増えつつ更新されていく様子がわかるgifが出力されると思います。
1
$ ruby main.rb


ちなみに、決定境界の変化も同時に描画させるようにすると図2のようなgifができます。前回の main.rbでgifをつくってるループのなかにボロノイ図を描画する行を挿入するだけですが、これをやると出力にものすごく時間がかかります。

図2 学習している様子のgif


おわりに

特に調べることもなく、なんとなく思いついて改良してみたのでこの方法の名前とかはわからないです。前回のやつは代表ベクトルの数が固定だったので、全てのクラスに同じ数の代表ベクトルがある状態でした。各クラスの分布の複雑さによって必要な代表ベクトルの数も変わってくるし可変にしたいなと思ってこんな感じになりました。

0 件のコメント:

コメントを投稿