データ科学特論 その 9: サポートベクターマシン 1 今回行う分析 1. 偽札の判別 2. スパムメールの判別 2 はじめに 現在, (考え方は古いが)新しい機械学習手法として人気があるサポートベクターマシン (SVM)を紹介する。3 層のニューラルネットワークと同等だが,未知のデータにもできる だけ良い性能を出そうとする,汎化を意識した学習を,マージン最大化として行う点が特徴 である。 なお,ここでは 2 クラスの判別のみを扱う。多クラスの判別や連続変数に対する回帰 SVM などもあるが,ここでは扱わない。 3 基本 教科書第 8 章図 8.1 を見てみよう.図 8.1 では,2 つの予測変数 X = (x, y) をそれぞれ縦 軸と横軸にとり,2 つのクラスを色分けして散布している.この図において 2 つのクラスを 判別(つまり x と y からどちらのクラスに属するかを判断する)するには,次のように直線 を考えれば良い。 y = ax + b ここで a, b は定数である。このような状況を,線形分離可能と呼ぶ。最小二乗法ではプロッ トに当てはまるような直線を求めていたが,ここでは 2 種類のプロットを分離するような直 線を考えることに注意。 1 しかし,図 8.2 については,線形分離は不可能である。上記のような直線では分離できな いからである。 このような場合に,次のように 2 次元を 3 次元に変換する関数を考えてみる。 √ Φ(x, y) = (x2 , y 2 , 2xy) (1) すると,図 8.3 を見ると,平面で分離することが出来る。このように予測変数の空間をいっ たん高次元にして,超平面(一つ低い次元の部分空間)で分離できるようにするのが,サポー トベクターマシン (SVM ) の基本的な考え方である。 3.1 超平面 まず,先の分離のための直線を超平面に一般化する. 先の直線は,行列を使って, x−y+b=0 ( ) ( ) x +b=0 ⇐⇒ 1 −1 y と表すことができる. これを m 次元の予測変数 X を変数 x1 , · · · , xm の縦行列で表し1 て一般化すると, wX + b = 0 で表すことができる.ただし w は m 次元の重みベクトルである. 判別のときには, wX + b ≥ 0 かどうかで判別すれば良いが, { f (x) = 1 if x ≥ 0, −1 otherwise, (2) というしきい値関数を使えば, f (wX + b) と表せば,この値が 1 か-1 かで判断できる。(2 層のニューラルネットワークに対応する。) 1 ここではデータアイテムの列ではなく個々の予測変数の列として表すことにする 2 3.2 マージン最大化 2 つのクラスを分離するある超平面をとったとき,その超平面から最も近い観測点(デー タアイテム)およびその反対方向に同じ距離だけ超平面を平行移動してできる領域を,マー ジンと呼ぶ2 . 分離する超平面の取り方は複数あるから,それに応じてマージンが変わってくる。汎化の 観点からは,なるべくマージンが大きい超平面を取る方が,未知のデータに対して余裕があ ることになる。 データセット x1 , · · · xn と超平面の最小ユークリッド距離は次で表現される3 。 min i=1,···,n |wxi + b| ||w|| この値を最大化する w を求めたいのだが,超平面 wX + b = 0 に 0 以外の定数をかけても 同じ平面を表すので,この縮尺をあわせるために, 分子を min |wxi + b| = 1 i=1,···,n と固定し,同時に,全ての xi について |wxi + b| ≥ 1 とする. (ここは実は不明。) このようにして以下の最適化問題を設定する。 目的関数: ||w|| を最小化 制約条件: 各 i = 1, · · · , m について |wxi + b| ≥ 1 3.3 双対問題とラグランジュ未定乗数法 これを,次のような最適化問題に変形する。 目的関数: ||w|| 2 2 を最小化 制約条件: 各 i = 1, · · · , m について |wxi + b| ≥ 1 2 3 これは厳密にはハードマージンであり,ソフトマージンでは領域内にデータアイテムがあることを許す。 √ ||w|| は w の長さ(ノルム)のことであり, wwt で定義される。(t は転置記号。) 3 これは線形の不等式を制約条件とする二次関数の最適化問題であり,このような問題は, ラグランジュの未定乗数法を使って解くことが出来る。以下の双対問題に対比して主問題と 呼ぶ。 ラグランジュの未定乗数法では,未定乗数と呼ばれる変数 αi ≥ 0 を用いて以下のようなラ グランジュ関数を最小化する。 ∑ 1 L(w, b, α) = ||w||2 − αi (|wxi + b| − 1) 2 n i=1 先ほどのしきい値関数 (2) を fi = f (wxi + b) と表すと,|wxi + b| = fi × (wxi + b) と表せる ので, n ∑ 1 L(w, b, α) = ||w||2 − αi {fi × (wxi + b) − 1} 2 i=1 ∑ ∑ ∑ 1 = ||w||2 − w αi fi xi t − b αi fi + αi 2 n n n i=1 i=1 i=1 (3) w または b を変化させて上記ラグランジュ関数が最小となるとき,その傾きは 0 となる。 よって偏微分により, ∑ ∑ ∂L =w− αi fi xi = 0 ⇐⇒ w = αi fi xi ∂w (4) ∂L ∑ = fi αi = 0 ∂b (5) n n i=1 i=1 n i=1 この結果を L(w, b, α) の定義式 (3) に代入すると, ∑ 1 L(α) = ||w||2 − wwt − 0 + αi 2 m i=1 = n ∑ i=1 = n ∑ i=1 1 αi − ||w||2 2 1 ∑∑ αi − αi αj fi fj xi xj t 2 n n i=1 j=1 4 (6) この式を目的関数とし,制約条件を先に出た αi ≥ 0 および n ∑ yi αi = 0 i=1 として最適化する問題を,主問題に対する双対問題と呼ぶ。 あとはこの双対問題を最急降下法などの一般の解法で α について解いていく。α が分かれ ば式 (4) より w は分かる。 b については,αi ̸= 0 であるデータアイテム xs により, |wxs t + b| = 1 を解くことにより求まる。上記のような αi ̸= 0 であるデータアイテムをサポートベクターと 呼ぶ。 3.4 双対問題の利点 双対問題に変換することには,次のような利点がある。 1. 主問題に比べて簡単な制約条件となる。 2. αi はほとんど 0 となり,計算においては α ̸= 0 であるデータアイテム,つまりサポー トベクターのみ考慮すれば良く,きわめて効率よく計算できる。 3. 式が x の内積 xxt で表されているため,後に説明するカーネルトリックの結果をその まま代入して使える。 非線形 SVM 4 ここまでは,予測変数 X の空間をいったん高次元にして線形分離可能でないものを分離 可能にするという考え方は用いてこなかった。それを用いる場合を,非線形 SVM と呼ぶが, ここではカーネルトリックという考え方が重要になる。 4.1 カーネルトリック 先の関数 Φ を用いて高次元化した予測変数に対する双対問題は次のようになる。 L(α) = n ∑ i=1 1 ∑∑ αi αj fi fj Φ(xi )Φ(xj )t 2 n αi − n i=1 j=1 5 ここで,Φ(xi )Φ(xj )t の計算は,何も考えなければそのまま計算しなければすることにな るが,Mercer(マーセル)の定理と呼ばれる条件を満たす場合は, Φ(xi )Φ(xj ) = k(xi , xj ) となるような関数が存在する。すなわち Φ により高次元化して内積を取るのではなく,直 接値を計算できる.ただし実際には高次元化の前に内積を取って計算する物が多い。それで もたくさんの計算を行う場合には効率化が図れる。このような関数 k をカーネルと呼び,Φ の代わりに k を使って計算することをカーネルトリックと呼ぶ。 カーネルが存在すれば,L(α) は次のように表現できる。 L(α) = n ∑ n i=1 j=1 i=1 4.2 1 ∑∑ αi αj yi yj k(xi , xj ) 2 n αi − さまざまなカーネル 代表的なカーネルは以下の 4 つである。 線形カーネル: k(x, y) = xyt 多項式カーネル: k(x, y) = (γxyt + δ)d RBF(ガウシアン) カーネル: シグモイドカーネル: k(x, y) = e−γ||x−y|| 2 k(x, y) = tanh(γ(xyt ) − δ),ただし tanh(x) = e2x − 1 e2x + 1 多項式カーネルのシンプルな例 (γ = 1, δ = 0, d = 2) だが,2 次元データアイテム (x, y) 自 身の内積を考えてみる。 ( ) ( ) x t 2 k(x, x) = (xx ) = { x y }2 = (x2 + y 2 )2 = x4 + y 4 + 2x2 y 2 y これは式 (1) で示した高次元化の関数と同様になる。なぜなら Φ(x)Φ(x)t = ( x2 y 2 2 x ) √ 2xy y 2 = x4 + y 4 + 2x2 y 2 √ 2xy 6 より上式と等しくなるからである。 上述のうち後ろ 3 つのカーネルには d,γ というパラメータが存在する。これらのパラメー タは識別期の精度に大きく影響するため,交差妥当性に配慮しつつ,チューニングする必要 がある。 4.3 ハードマージンとソフトマージン マージン内に少しくらいはデータアイテムの存在を許す場合をソフトマージンと呼ぶ。そ の場合は, 不等式制約を |wxi t + b| ≥ 1 − ζi , ζi ≥ 0 のように変形する。ζi はスラック変数と呼ばれ,対応するデータアイテム xi に対してマージ ンがその分だけ狭く,または逆方向に飛び出していることを示す。 これらをまとめて制約する定数 C を導入し,最適化問題を以下のように変更する。 目的関数: ||w|| 2 2 +C ∑n i=1 ζi 制約条件: fi (wxi + b) ≥ 1 − ζi , ζi ≥ 0, C > 0 ここまでに出たパラメータ d,γ,C をチューニングパラメータと呼ぶが,これらを経験的に 探索する手法をグリッドサーチと呼ぶ。そのための手法が R には用意されている。 R で SVM 5 5.1 偽札データの分析 データ読み込み > library(e1071) > お札<-read.csv("chap8/お札.csv", header=T, fileEncoding="sjis") 学習 > お札結果<- svm(真偽~.,data=お札, probability=T, kernel="radial", cross=1) 引数 kernel は RBF を示す radial となっている。線形カーネルなら linear, 多項式カー ネルなら polynomial,シグモイドカーネルなら sigmoid とすれば良い。 7 引数 probability=T はあとの predict でデータアイテムがクラスに所属する確率を出力 するために必要。 引数 cross は重交差妥当化のためのデータ分割数である。データ数が多ければ分割数を多 めに取ると良い。 予測 > お札予測値<-predict(お札結果,newdata=お札, probability=T, decision.values=T) 引数 decision.value=TRUE とすれば,f (xi ) の予測値も算出されるようになる.引数 probability=T はデータアイテムがクラスに所属する確率を出力することを指定している。 クロス表 > table(お札予測値, お札 [,1]) お札予測値 偽札 真札 偽札 100 1 真札 0 99 解の抽出 > お札結果 Call: svm(formula = 真偽 ~ ., data = お札, probability = T, kernel = "radial", cross = 1) Parameters: SVM-Type: SVM-Kernel: cost: gamma: C-classification radial 1 0.5 Number of Support Vectors: 19 cost はチューニングパラメータの C ,gamma は γ を表している。 Number of Support Vectors: 19 では,本モデルにおけるサポートベクターが 19 個で あると言っている。 8 サポートベクターの表示 > t(round(head(お札結果$SV,10),3)) 3 4 7 11 13 27 29 x1 -0.497 0.334 -0.081 -0.981 0.196 -0.012 0.472 x2 -0.246 -0.159 -0.246 -1.114 0.101 -0.246 -0.159 61 71 80 x1 -1.397 -0.497 -0.843 x2 -1.808 -2.329 -2.329 しきい値 b > round(お札結果$rho,3) [1] -0.232 γとC > (ガンマ<-お札結果$gamma) [1] 0.5 > (C<-お札結果$cost) [1] 1 予測値と所属確率 > 予測値<- round(attr(お札予測値, "decision.values"),3) > 予測確率<-round(attr(お札予測値,"probabilities"),3) > t(head(予測値,8));t(tail(予測値,8)) 1 2 3 4 5 6 7 偽札/真札 1.518 1.584 0.157 0.899 1.26 1.566 0.658 8 偽札/真札 1.454 193 194 195 196 197 198 偽札/真札 -1.752 -1.347 -1.621 -1.759 -1.26 -1.583 199 200 偽札/真札 -1.668 -1.512 > t(head(予測確率,8));t(tail(予測確率,8)) 1 2 3 4 5 6 7 8 偽札 0.993 0.995 0.598 0.946 0.984 0.994 0.89 0.991 真札 0.007 0.005 0.402 0.054 0.016 0.006 0.11 0.009 193 194 195 196 197 198 199 200 偽札 0 0.009 0.004 0 0.012 0.004 0.003 0.005 真札 1 0.991 0.996 1 0.988 0.996 0.997 0.995 attr は変数に含まれる特定の情報を抽出するための関数。ここではお札予測値という変 数の decision.values と probabilities を抽出している。attr が必要かどうかは str 関 数で調べると良い。 9 グリッドサーチ > チューニング<-tune.svm(真偽~.,data=お札,gamma=2^c(seq(-2,1,0.25)),cost=2^c(seq(-2,1,0.25))) γ と C それぞれに対して 2−2 から 21 の区間で指数に 0.25 刻みにグリッドを設けている。 これは gamma および cost という引数に,2 にベクトル seq(-2,1,-0.25) 乗することにより 指定している。 誤判別率の抽出 > 結果<-summary(チューニング) > str(結果 [[7]]) ’data.frame’: 169 obs. of 4 variables: $ gamma : num 0.25 0.297 0.354 0.42 0.5 ... $ cost : num 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 ... $ error : num 0.01 0.01 0.01 0.01 0.005 0.01 0.01 0.01 0.005 0.005 ... $ dispersion: num 0.0316 0.0316 0.0316 0.0316 0.0158 ... 3 次元プロット > library(scatterplot3d) > scatterplot3d(結果 [[7]][,1], 結果 [[7]][,2], 結果 [[7]][,3],xlab="gamma",ylab="C",zlab="error rate") 5.2 スパムの判別 データ読み込み > library(kernlab) > data(spam) > head(spam) make address all num3d our over remove internet 1 0.00 0.64 0.64 0 0.32 0.00 0.00 0.00 2 0.21 0.28 0.50 0 0.14 0.28 0.21 0.07 3 0.06 0.00 0.71 0 1.23 0.19 0.19 0.12 order mail receive will people report addresses free 1 0.00 0.00 0.00 0.64 0.00 0.00 0.00 0.32 2 0.00 0.94 0.21 0.79 0.65 0.21 0.14 0.14 3 0.64 0.25 0.38 0.45 0.12 0.00 1.75 0.06 business email you credit your font num000 money hp 1 0.00 1.29 1.93 0.00 0.96 0 0.00 0.00 0 2 0.07 0.28 3.47 0.00 1.59 0 0.43 0.43 0 3 0.06 1.03 1.36 0.32 0.51 0 1.16 0.06 0 10 ●● ● ●● ● ●● ● ●● ● ● ● ● ●● ● ● ● ● 0.005 0.006 ●● ● ● ● ● ●● ● ● ● ● ● ●● ● ● ● ● ● ● ●● ● ● ● ● ● ● ●● ● ● ● ● ● ● ● ● ● ●●●● ● ●● ● ● ●●● ● ●● ●● ●● ●● ●● ● ● ● ● ● ● ● ● ● ● ● ● ● ● ●● ● ● ● ● ● ● ● ● ● ● ● ● ● ●● ● ● ● ● ● ●● ● ● ●● ● ● ● ● 0.5 0.0 0.0 0.5 1.0 1.5 2.0 gamma 図 1: default 1 2 3 1 2 3 1 2 3 1 2 3 1 2 3 1 2 hpl george num650 lab labs telnet num857 data num415 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 num85 technology num1999 parts pm direct cs meeting 0 0 0.00 0 0 0.00 0 0 0 0 0.07 0 0 0.00 0 0 0 0 0.00 0 0 0.06 0 0 original project re edu table conference 0.00 0 0.00 0.00 0 0 0.00 0 0.00 0.00 0 0 0.12 0 0.06 0.06 0 0 charSemicolon charRoundbracket charSquarebracket 0.00 0.000 0 0.00 0.132 0 0.01 0.143 0 charExclamation charDollar charHash capitalAve 0.778 0.000 0.000 3.756 0.372 0.180 0.048 5.114 0.276 0.184 0.010 9.821 capitalLong capitalTotal type 61 278 spam 101 1028 spam 11 ● ● ● ● ● ● ● ● 1.0 1.5 2.0 C 0.008 0.007 error rate 0.009 0.010 ● ●● ● ● ● ● ●● ● ●● ●● ● ●● ●●●● ● ●●●● ● ● ● ● ● 3 485 2259 spam データの分割 データを分割するために,a というベクトルから b 個のデータ(要素番号)を無作為抽出 する sample(a,b) という関数を使う。 > index<-sample(1:dim(spam)[1],dim(spam)[1]/2) > 学習スパム<-spam[-index,] > テストスパム<-spam[index,] 学習 > > > > RBF モデル<-svm(type~.,data=学習スパム,kernel="radial") 多項式モデル<-svm(type~.,data=学習スパム,kernel="polynomial") シグモイドモデル<-svm(type~.,data=学習スパム,kernel="sigmoid") 線形モデル<-svm(type~.,data=学習スパム,kernel="linear") 予測 > > > > RBF 予測結果<-predict(RBF モデル, テストスパム [,-58]) 多項式予測結果<-predict(多項式モデル, テストスパム [,-58]) シグモイド予測結果<-predict(シグモイドモデル, テストスパム [,-58]) 線形予測結果<-predict(線形モデル, テストスパム [,-58]) クロス表 > table(RBF 予測結果, テストスパム$type) RBF 予測結果 nonspam spam nonspam 1342 113 spam 50 795 > table(多項式予測結果, テストスパム$type) 多項式予測結果 nonspam spam nonspam 1355 548 spam 37 360 > table(シグモイド予測結果, テストスパム$type) シグモイド予測結果 nonspam spam nonspam 1284 149 spam 108 759 12 > table(線形予測結果, テストスパム$type) 線形予測結果 nonspam spam nonspam 1330 103 spam 62 805 6 宿題 今回 R で行った内容を,自分の計算機環境で再現してください。 13
© Copyright 2024 ExpyDoc