Opt Technologies Magazine

オプトテクノロジーズ 公式Webマガジン

ベイズニューラルネットワークの変分推論

alt

業務で少し触れる機会のあったベイズニューラルネットワークについて紹介していこうと思います.

あいさつ

AIソリューション開発部の高野です.広告効果の予測モデルの開発を行っています. ここでは業務で少し触れる機会のあったベイズニューラルネットワークについて紹介していこうと思います.

記事の流れ

ベイズニューラルネットワークに触れる前に、その構成要素であるベイズ推論や深層学習について順を追って説明していきます. 具体的には、初めにベイズ推論について簡単に説明します.その後、線形回帰の発展として深層学習の確率モデルについて説明します.最後に全てを合わせたベイズニューラルネットワークとその推定方法(変分推論)について、説明しようと思います.

ベイズ推定

ベイズの哲学的な側面は説明せず、ここではベイズ推論について簡単に説明していきます。 ベイズ推論には以下のような2つの特徴があります.

  • 未知の量(パラメータ)を全て確率分布で表現する
  • ベイズの定理に従って、観測したデータを条件にパラメータの分布を更新する

観測したデータから情報を得ることで、パラメータの不確実性が減っていくことが期待できます。 具体的な更新には、以下の式(ベイズの定理)が用いられます.

$$ p(\theta|\mathcal{D}) = \frac{p(\mathcal{D}|\theta)p(\theta)}{p(\mathcal{D})} = \frac{p(\mathcal{D}|\theta)p(\theta)}{\int_{\Theta} p(\mathcal{D}|\theta) p(\theta) d\theta} \tag{1} $$

ここで、$\theta \in \Theta$ はパラメータ、$\mathcal{D}$ は観測したデータを表します.得られた量 $p(\theta|\mathcal{D})$ は事後分布と呼ばれます. $(1)$ 式は直感的には各パラメータ $\theta$ をそのパラメータ下でのデータ $\mathcal{D}$ の生じやすさ $p(\mathcal{D}|\theta)$ で重みづけて、$1$ に正規化して確率に戻す処理を行なっています. 情報をもとに不確実性(確率)の再割り当てを行なっていると考えられます.

事前分布

$(1)$ 式の $p(\theta)$ は事前分布と呼ばれ、あるパラメータに対して観測値を得る前にあらかじめもっている知識や不確実性を表現するのに用いられます.通常は事前の知識や過去の研究結果を元に分布が設定されますが、ベイズニューラルネットワークではそのように設定するのが難しいこともあり正規分布が用いられることが多いです.その平均と標準偏差の設定法には、通常のニューラルネットワークにおける重みの初期化方法(He initialization 等)に従って指定する方法や、ガウス分布ではなくラプラス分布を用いることでスパース性を持たせるといったこともできます.

尤度

$(1)$ 式の $p(\mathcal{D}|\theta)$ は尤度と呼ばれる量で、ある特定のパラメータ $\theta$ による観測値の得られやすさを表す量であり、パラメータの関数と捉えることができます.ベイズニューラルネットワークでは、ニューラルネットワークの構造を変更すると、この部分に反映されます.

周辺尤度(エビデンス)

$(1)$ 式の $p(\mathcal{D})$ は周辺尤度やエビデンスと呼ばれ、事前分布と尤度を用いて以下のように計算されます.

$$ p(\mathcal{D}) = \int_{\Theta} p(\mathcal{D}|\theta) p(\theta) d\theta $$

この積分計算は一般的には解析的に解けないので、この計算を迂回した形で事後分布 $p(\theta|\mathcal{D})$ を求められるかが重要になります.迂回した形で求める手法には主に2つに分けられます.

  • サンプリングを用いた漸近的に正確な方法(どれだけサンプリングすれば良いかはわからない)
  • 変分推論を用いた高速な方法(漸近的にも正確性は保証されていない)

ベイズニューラルネットワークではパラメータの数が膨大なため、計算効率をあげるために後者の変分推論が用いられることが多いです.(漸近的に正確な)サンプリングを用いて事後分布を推定することでベイズニューラルネットワークの性能を検証した論文も存在します.[2104.14421] What Are Bayesian Neural Network Posteriors Really Like?

深層学習

ここでは、まず線形回帰に関して簡単に説明してそれを徐々にニューラルネットワークへと拡張していきます. 最後に、深層学習と呼ばれるモデルについて少しだけ触れます.

線形回帰

ここでは、線形回帰モデルについて簡単に説明していきます.線形回帰モデルは出力 $y \in \mathbb{R}$ と入力 $\mathbf{x} \in \mathbb{R}^{D}$ の関係をモデル化する際に用いられ、以下の式で表現されます.

$$ p(y|\mathbf{\theta}, \mathbf{x}) = \mathcal{N}(y|w_0 + \mathbf{w}^T \mathbf{x}, \sigma^{2}) $$

$\mathbf{\theta} = (w_0, \mathbf{w}, \sigma^{2})$ であり、これがモデルのパラメータになります. 表記を簡便にするために、$\mathbf{x} = [1, x_1, \ldots, x_D]^T$ とし、$w_0$ を $\mathbf{w}$ に吸収する形で表現されることが多いため、ここからはそのように扱うことにします. そのため、上記の式は以下のようになります.

$$ p(y|\mathbf{\theta}, \mathbf{x}) = \mathcal{N}(y|\mathbf{w}^T \mathbf{x}, \sigma^{2}) \tag{2} $$

この単純な線形回帰モデルは、入力に対して線形の関係しか表現できないという重大な欠点があります.この欠点は、非線形変換 $\mathbf{\phi}(\cdot): \mathbb{R}^{D} \to \mathbb{R}^{N}$ を用いて入力 $\mathbf{x}$ を変換することで解消することができます.この変換により非線形な関係を表現できるようになるため、ある程度複雑な入出力関係を捉えられるようになります. この変換を施したモデルは以下のようになります.

$$ p(y|\mathbf{w}, \mathbf{x}) = \mathcal{N}(y|\mathbf{w}^T \mathbf{\phi}(\mathbf{x}), \sigma^{2}) $$

$(2)$ 式のモデルとは、$\mathbf{w}$ の次元が異なる点に注意してください. このモデルは、$\mathbf{\phi}(\cdot)$ の設計次第で多くの関数を表現できます.加えて、パラメータに関しては線形であるから推定も容易であることが知られています.しかしながら、このモデルにも欠点は存在しています.それは $\mathbf{\phi}(\cdot)$ をあらかじめ設計しなければならない点にあります.これはデータ分析における特徴量エンジニアリングに対応する作業となっており、ドメイン知識や問題理解が要求される中々に難しい作業です.上述の欠点の解消のため、上記の変換 $\mathbf{\phi}(\cdot)$ にパラメータをもたせることで、データから変換を学習できるようにすることができます. このようにすることで、その問題に合った変換を学習できるようになります.これがニューラルネットワークです.

ニューラルネットワーク

出力の多次元化

上記の線形回帰をニューラルネットワークへと拡張する前に、出力も多次元ベクトル $\mathbf{y} \in \mathbb{R}^{H_L}$ の場合のモデルへと拡張します. このモデルは、多変量線形回帰と呼ばれます.

$$ p(\mathbf{y}|\mathbf{W}_1, \mathbf{x}) = \mathcal{N}_{H_L}(\mathbf{y}|\mathbf{W}_1 \mathbf{x}, \Sigma) $$

ここで、入力は $\mathbf{x} \in \mathbb{R}^{H_0}$、パラメータは $\mathbf{W}_1 \in \mathbb{R}^{H_L \times H_0}$ であり、$\Sigma$ は $H_L \times H_L$ の分散共分散行列です.

変換の学習(二層のニューラルネットワーク)

変換 $\mathbf{\phi}(\cdot)$ にパラメータをもたせて自動的に変換を学習させるように拡張していきますが、ニューラルネットワークでは以下のように線形回帰と似た形でパラメータを持たせます. その上で非線形変換 $\sigma_1: \mathbb{R} \to \mathbb{R}$ をベクトルの要素毎に行います.

$$ \mathbf{\phi}(\mathbf{x}|\mathbf{W}_2) = \mathbf{\sigma}_1 (\mathbf{W}_2 \mathbf{x}) $$

このとき、モデルは

$$ p(\mathbf{y}|\mathbf{W}_1, \mathbf{W}_2, \mathbf{x}) = \mathcal{N}_{H_L}(\mathbf{y}|\mathbf{W}_1 \sigma_1(\mathbf{W}_2 \mathbf{x}), \Sigma) \tag{3} $$

と表現できます.ここで、$\mathbf{W}_1 \in \mathbb{R}^{H_L \times H_1}, \mathbf{W}_2 \in \mathbb{R}^{H_1 \times H_0}$ です.上記の $\mathbf{W}_2$ をデータから学習することで非線形な関数を表現できるようになります. これは二層のニューラルネットワークですが、さらに非線形変換を繰り返して層を深めていくことで、より深いニューラルネットワークが構成できます.

非線形変換の理由は、$\sigma_1$ が恒等関数の場合を考えると

$$ \mathbf{W}_1 \sigma_1(\mathbf{W}_2 \mathbf{x}) = (\mathbf{W}_1 \mathbf{W}_2)\mathbf{x} $$

であるので、$\mathbf{W} = \mathbf{W}_1 \mathbf{W}_2$ とすることで

$$ p(\mathbf{y}|\mathbf{W}_1, \mathbf{W}_2, \mathbf{x}) = \mathcal{N}_{H_L}(\mathbf{y}|\mathbf{W} \mathbf{x}, \Sigma) $$

となり、多変量線形回帰と同様になってしまいパラメータを持たせた意味がなくなるからで、層を深めるためには必須になります.

$(3)$ 式が、最も簡単なニューラルネットワークであり、層数で数えると二層ですが、中間層の次元($\mathbb{R}^{H_1}$)を無限大にするとあらゆる関数を表現できることが知られています(万能近似定理). しかしながら、そのような関数を学習できるかや必要なデータ数、その近似精度はまた別問題なので注意が必要です.

一般のニューラルネットワーク

一般には、ニューラルネットワークは複数の関数の合成関数として捉えることができるため、以下のように整理することができます. 層数 $L$ の場合を考えて、出力は $\mathbf{y} \in \mathbb{R}^{H_L}$、入力は $\mathbf{x} \in \mathbb{R}^{H_0}$ とします.また、各層 $l$ のベクトル値関数を $\boldsymbol{f}_l: \mathbb{R}^{H_{l-1}} \to \mathbb{R}^{H_l}$ としたとき、以下のように表現できます.

$$ p(\mathbf{y}|\mathbf{W}, \mathbf{x}) = \mathcal{N}_{H_L}(\mathbf{y}|\boldsymbol{f}(\mathbf{x}; \mathbf{W}), \Sigma) $$

$$ \boldsymbol{f}(\mathbf{x}|\mathbf{W}) = (\boldsymbol{f}_L \circ \boldsymbol{f}_{L-1} \circ \cdots \circ \boldsymbol{f}_1)(\mathbf{x}) $$

例えば、$(3)$ 式は以下のように表せます.

$$ \begin{aligned} \boldsymbol{f}_1(\mathbf{x}) &= \sigma_1(\mathbf{W}_2 \mathbf{x}) \\ \mathbf{h}_1 &= \boldsymbol{f}_1(\mathbf{x}) \\ \boldsymbol{f}_2(\mathbf{h}_1) &= \mathbf{W}_1 \mathbf{h}_1 \\ \mathbf{y} &= \boldsymbol{f}_2(\mathbf{h}_1) \end{aligned} $$

深層学習

上記で説明した線形回帰の拡張としてのニューラルネットワークは深層学習の一部、マルチレイヤーパーセプトロンやフィードフォワードネットワークと呼ばれており、この他にも様々な構造があります.深層学習は、微分可能な合成関数を様々な形で組み合わせたものと言えます.DAGで表現できて微分可能であれば、現在の深層学習ライブラリで学習が可能なため、CNN、Transformerといった様々な構造のモデルが実際に使われています.

以上で、ニューラルネットワークと深層学習がどのようなものかがわかったと思うので、次はパラメータの推定手法とそのアルゴリズムについて説明していきます.

ベイズニューラルネットワーク(BNN)

ここでは、上記で説明したニューラルネットワークをベイズ化することを考えます. まず、最初に説明したように未知の量(ここでは、パラメータ$\mathbf{W}$)に(事前)分布を設定します.そして、ベイズの定理を用いてその分布を更新していきます. 基本的には上記の通りベイズ推定を行うだけですが、ニューラルネットワークの事後分布は解析的に計算できないため、周辺尤度の部分で説明したように近似推論を用いて事後分布を計算する必要があります. 上で話したように、BNNの近似推論は計算効率のため変分推論が用いられることが多いため、今回は変分推論のみを説明します.

変分推論

まず簡単に変分推論をについて説明します.簡単にいうと、パラメータを持った扱い易い分布を事後分布に近づけていくことで得られた分布を近似解とする方法です.近さの判定には、KLダイバージェンスが用いられることが多く、パラメータを持った扱い易い分布は正規分布が選ばれることが多いです.もちろん他の指標や分布を用いることも可能です.気になる方は $\alpha$-Divergence や normalizing flow を検索してみてください. KLダイバージェンスは、以下の式で表される量です.

$$ \mathcal{KL}[q|p] = \int \log \frac{q(x)}{p(x)} q(x) dx $$

KLダイバージェンスの主な性質としては、

  1. $\mathcal{KL}[q|p] \ge 0$
  2. $\mathcal{KL}[q|p] \neq \mathcal{KL}[p|q]$

があります.上記2の性質により、$\mathcal{KL}[q|p]$ と $\mathcal{KL}[p|q]$ のどちらの式を用いて近さを判定するかによって得られる結果が異なります. $q$ を近似分布、$p$ を事後分布とした時に、$\mathcal{KL}[q|p]$ は reverse KL と呼ばれ、$\mathcal{KL}[p|q]$ は forward KL と呼ばれます. 変分推論では、reverse KL が用いられます.それぞれの式によって得られる結果には特徴が存在しており、reverse KL は事後分布のモード(最頻値)を捉えるような解が得られるため、分散を過小評価する傾向があります. これまでを議論を式で表すと、データセットを $\mathcal{D}$ で表すと

$$ q^{*}(\mathbf{w}|\mathbf{\eta}) = \arg\,min_{q \in Q} \mathcal{KL}[q(\mathbf{w}|\mathbf{\eta})|p(\mathbf{w}|\mathcal{D})] $$

のように簡潔に表すことができ、汎関数の最小化問題として定式化されます.そのため、変分という名前がついています.ここで、$\mathcal{\eta}$ は近似分布のパラメータで、最適化の対象になります. ここから、上記の最適化問題を解いていきます.まず、

$$ \begin{align} \mathcal{KL}[q|p] &= \int \log \frac{q(\mathbf{w}|\mathbf{\eta})}{p(\mathbf{w}|\mathcal{D})} q(\mathbf{w}|\mathbf{\eta}) d\mathbf{w} \\ &= \int \log \frac{q(\mathbf{w}|\mathbf{\eta})}{\frac{p(\mathbf{w}, \mathcal{D})}{p(\mathcal{D})}} q(\mathbf{w}|\mathbf{\eta}) d\mathbf{w} \\ &= - \int \log \frac{p(\mathbf{w}, \mathcal{D})}{q(\mathbf{w}|\mathbf{\eta})} q(\mathbf{w}|\mathbf{\eta}) d\mathbf{w} + \int \log p(\mathcal{D}) q(\mathbf{w}|\mathbf{\eta}) d\mathbf{w} \\ &= - E_{q} \left[\log \frac{p(\mathbf{w}, \mathcal{D})}{q(\mathbf{w}|\mathbf{\eta})} \right] + \log p(\mathcal{D}) \\ \end{align} $$

整理すると、

$$ \log {p(\mathcal{D})} = E_{q} \left[\log \frac{p(\mathbf{w}, \mathcal{D})}{q(\mathbf{w}|\mathbf{\eta})} \right] + \mathcal{KL}[q(\mathbf{w}|\mathbf{\eta})|p(\mathbf{w}|\mathcal{D})] $$

となります.この時、$\mathcal{KL}[q|p] \ge 0$ であることから、$E_{q} \left[\log \frac{p(\mathbf{w}, \mathcal{D})}{q(\mathbf{w}|\mathbf{\eta})} \right]$ はエビデンス(周辺尤度)の対数の下限です.そのため、Evidence Lower BOund(ELBO) と呼ばれています.

$$ \mathcal{L}[q] := E_{q} \left[\log \frac{p(\mathbf{w}, \mathcal{D})}{q(\mathbf{w}|\mathbf{\eta})} \right] $$

さらに、$\mathcal{D}$ が与えられたとき、$\log {p(\mathcal{D})}$ は固定されるので、$\mathcal{KL}[q|p]$ の最小化は、$\mathcal{L}[q]$ の最大化と等価です.$\mathcal{KL}[q|p]$ の最小化は、$\mathcal{L}[q]$ の最大化が等価であることがわかったことの大きな利点は、$\mathcal{KL}$ には、解析的に計算できない事後分布 $p(\mathbf{w}|\mathcal{D})$ が存在していたが、$\mathcal{L}$ では、$p(\mathbf{w}, \mathcal{D})$ となっている点です.$\log p(\mathbf{w}, \mathcal{D})$ は $\log p(\mathcal{D}|\mathbf{w})$ と $\log p(\mathbf{w})$ に分けられるため、計算が難しい $\log p(\mathcal{D})$ が消えさり、計算が容易になります. ELBO を式変形すると、

$$ \begin{align} \mathcal{L}[q] &= E_{q} \left[\log \frac{p(\mathbf{w}, \mathcal{D})}{q(\mathbf{w}|\mathbf{\eta})} \right] \\ &= E_{q} \left[\log p(\mathbf{w}, \mathcal{D}) \right] - E_{q} \left[\log q(\mathbf{w}|\mathbf{\eta}) \right] \\ &= E_{q} \left[\log p(\mathcal{D}|\mathbf{w}) \right] + E_{q} \left[\log p(\mathbf{w}) \right] - E_{q} \left[\log q(\mathbf{w}|\mathbf{\eta}) \right] \\ &= E_{q} \left[\log p(\mathcal{D}|\mathbf{w}) \right] - \mathcal{KL}[q(\mathbf{w}|\mathbf{\eta})|p(\mathbf{w})] \\ \end{align} $$

となります.第一項は対数尤度の期待値であり、第二項は近似分布と事前分布の負のKLダイバージェンスです.そのため、ELBOの最大化の際に、第一項により近似分布はよりデータを説明するように最適化され、第二項は近似分布を事前分布から離れすぎないように正則化する働きをすることになる.ここで、後で使うために以下のように定義をしておきます.

$$ \begin{align} \mathcal{L}_{data}[q] &:= E_{q} \left[\log p(\mathcal{D}|\mathbf{w}) \right] \\ \mathcal{L}_{prior}[q] &:= \mathcal{KL}[q(\mathbf{w}|\mathbf{\eta})|p(\mathbf{w})] \end{align} $$

ニューラルネットワークの変分推論

上記の変分推論をニューラルネットワークに適用していきます.その際、近似分布には以下の正規分布を用います.

$$ q(\mathbf{w}| \mathbf{\eta}) = \prod_{i} \mathcal{N}(w_i|\mu_i, \sigma_i) $$

ここから、勾配上昇法を用いて $\mathbf{\eta}$ の最適化を行うために $\mathcal{L}[q]$ の $\mathbf{\eta}$ に対する勾配を計算していきます.$\mathcal{L}_{prior}[q]$ は事前分布と近似分布の組み合わせによっては、解析的に計算することが可能です(例えば、両方が正規分布の場合は計算可能です).そのため、計算で得られた結果を微分すれば問題ありません.しかしながら、解析的に計算が難しい場合は、$\mathcal{L}_{data}[q]$ の勾配近似と同様な方法を用いることで計算可能になります. $\mathcal{L}_{data}[q]$ の式を再喝すると、

$$ E_{q} \left[\log p(\mathcal{D}|\mathbf{w}) \right] = \int \log p(\mathcal{D}|\mathbf{w}) q(\mathbf{w}| \mathbf{\eta}) d\mathbf{w} $$

です.ほとんどの場合、この積分を解析的に計算するのは難しいため、計算で得られた結果の勾配を求めることができません.そこで、積分を解析的に計算することは諦めて、近似をしていきます.代表的な方法に分布からのサンプルを用いて期待値(積分)を近似するモンテカルロ積分があるので適用してみると、以下のようになります.

$$ \begin{align} \mathbf{w}_{i} &\sim q(\mathbf{w}|\mathbf{\eta}) \quad (i = 1, \ldots, N) \\ E_{q} \left[\log p(\mathcal{D}|\mathbf{w}) \right] &\approx \frac{1}{N} \sum_{i} \log p(\mathcal{D}|\mathbf{w}_{i}) \end{align} $$

これにより積分の値は近似できたのですが、$q(\mathbf{w}|\mathbf{\eta})$ が消えてしまうので、$\mathbf{\eta}$ に対しての勾配が計算できないため、勾配上昇法を用いることができなくなってしまいます. そのため、考え方を変えて、ここからは積分を計算してから勾配を計算するのではなく、勾配を直接計算・近似することを目指します.

勾配近似

$\mathcal{L}_{data}[q]$ の勾配は以下のように書けます.この勾配を直接近似することを考えます.

$$ \nabla_{\mathbf{\eta}} \mathcal{L}_{data}[q] = \nabla_{\mathbf{\eta}} E_{q} \left[\log p(\mathcal{D}|\mathbf{w}) \right] $$

まず、適当な正則条件の元で、積分と微分が入れ替えられることを用いると、

$$ \nabla_{\mathbf{\eta}} E_{q} \left[\log p(\mathcal{D}|\mathbf{w}) \right] = \int \log p(\mathcal{D}|\mathbf{w}) \nabla_{\mathbf{\eta}} q(\mathbf{w}| \mathbf{\eta}) d\mathbf{w} $$

と変形できます.ここから、以下の対数微分の関係式を用いて変形します.

$$ \nabla_{\mathbf{\eta}} \log q(\mathbf{w}|\mathbf{\eta}) = \frac{\nabla_{\mathbf{\eta}}q(\mathbf{w}|\mathbf{\eta})}{q(\mathbf{w}|\mathbf{\eta})} $$

変形した結果は以下のようになります.

$$ \int \log p(\mathcal{D}|\mathbf{w}) \nabla_{\mathbf{\eta}} q(\mathbf{w}| \mathbf{\eta}) d\mathbf{w} = \int \log p(\mathcal{D}|\mathbf{w}) \nabla_{\mathbf{\eta}} \log q(\mathbf{w}|\mathbf{\eta}) q(\mathbf{w}| \mathbf{\eta}) d\mathbf{w} $$

この式は、$E_{q}[\log p(\mathcal{D}|\mathbf{w}) \nabla_{\mathbf{\eta}} \log q(\mathbf{w}|\mathbf{\eta})]$ と期待値の形で表すことができ、モンテカルロ積分で近似を行えます. この手法は log derivative trick(score function estimater) と呼ばれており、勾配近似の代表的な手法の一つです. この手法は非常に汎用性が高いが、分散が非常に大きくなるという問題点があるため、制御変量法といった分散減少法を用いるか、reparametrization trick(pathwise gradient estimater)が使える場合はそちらが用いられます.BNNでは、reparametrization trick が使用可能なため、そちらが用いられます.

ここから、reparametrization trick を用いた変分推論に関して説明していきます. reparametrization trick では、正規分布

$$ w_i \sim \mathcal{N}(w|\mu_i, \sigma_i) $$

が、以下のように変換(再パラメータ化)できることに注目します.

$$ \begin{align} \epsilon_i &\sim \mathcal{N}(\epsilon|0, 1) \quad (i = 1, \ldots, N) \\ w_i &= \mu_i + \sigma_i \cdot \epsilon_i \end{align} $$

上記の関係式と、変数変換の前後で期待値は変わらない(Law of the unconscious statistician)ことを用いると、

$$ \nabla_{\mathbf{\eta}} E_{q(\mathbf{\eta})} \left[\log p(\mathcal{D}|\mathbf{w}) \right] = \nabla_{\mathbf{\mu}, \mathbf{\sigma}} E_{q(\mathbf{\epsilon})}\left[\log p(\mathcal{D}|\mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}) \right] $$

と式変形ができます.ここで、$\odot$ はベクトルの要素毎の積を表し、$\mathbf{\eta} = (\mathbf{\mu}, \mathbf{\sigma})$ であることを用いました.さらに、適当な正則条件の元で、積分と微分が入れ替えられることを用いると、

$$ \nabla_{\mathbf{\mu}, \mathbf{\sigma}} E_{q(\mathbf{\epsilon})}\left[\log p(\mathcal{D}|\mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}) \right] = E_{q(\mathbf{\epsilon})}\left[\nabla_{\mathbf{\mu}, \mathbf{\sigma}} \log p(\mathcal{D}|\mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}) \right] $$

となります.そのため、モンテカルロ積分を用いることが可能になり、$\mathbf{w}$ の次元を $d$ としたとき、以下のように近似できます.

$$ \begin{align} \mathbf{\epsilon}_{i} &\sim \mathcal{N}_{d}(\mathbf{\epsilon}|0, I) \quad (i = 1, \ldots, N) \\ \nabla_{\mathbf{\eta}} \mathcal{L}_{data}[q] &\approx \frac{1}{N} \sum_{ i} \nabla_{\mathbf{\mu}, \mathbf{\sigma}} \log p(\mathcal{D}|\mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}_{i}) \end{align} \tag{4} $$

上記の近似勾配と勾配上昇法を用いて、$\mathcal{L}[q]$ を最大化することが可能となりました.しかしながら、実は1つ問題点があります.$\sigma$ のとり得る値の範囲が $(0, \infty)$ であるということです.最大化の過程に $\sigma$ が負になることを防ぐため、以下のような変換を要素毎に行い $\rho$ を最適化することで、その問題を回避します.

$$ \sigma = \log (1 + \exp{\rho}) $$

ここまでで、変分推論の概要は説明できたので、より具体的な問題設定で上記の $(4)$ 式を計算していきます.

教師あり学習への適用

教師あり学習の問題設定として、$\mathcal{D} = \{ (\mathbf{x}, y)_{1}, (\mathbf{x}, y)_{2}, \ldots, (\mathbf{x}, y)_{N_{\mathcal{D}}} \}$ であり、$\mathbf{x}_{j}$ は所与で $\textit{i.i.d}$ であるケースを考えます.このとき、尤度は

$$ \log p(\mathcal{D}|\mathbf{w}) = \sum_{j} \log p(y_{j}|\mathbf{w}, \mathbf{x}_{j}) $$

のように分解できます.よって、ELBO は

$$ \mathcal{L}[q] = E_{q} \left[\sum_{j}^{N_{\mathcal{D}}} \log p(y_{j}|\mathbf{w}, \mathbf{x}_{j}) \right] - \mathcal{KL}[q(\mathbf{w}|\mathbf{\eta})|p(\mathbf{w})] $$

となります.さらに、$(4)$ 式は

$$ \begin{align} \mathbf{\epsilon}_{i} &\sim \mathcal{N}_{d}(\mathbf{\epsilon}|0, I) \quad (i = 1, \ldots, N) \\ \nabla_{\mathbf{\eta}} \mathcal{L}_{data}[q] &\approx \frac{1}{N} \sum_{i}^{N} \sum_{j}^{N_{\mathcal{D}}} \nabla_{\mathbf{\mu}, \mathbf{\sigma}} \log p(y_{j}|\mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}_{i}, \mathbf{x}_{j}) \end{align} $$

と書けるため、求めたい勾配は、以下のようになります.

$$ \begin{align} \mathbf{\epsilon}_{i} &\sim \mathcal{N}_{d}(\mathbf{\epsilon}|0, I) \quad (i = 1, \ldots, N) \\ \nabla_{\mathbf{\eta}} \mathcal{L}[q] &\approx \frac{1}{N} \sum_{i}^{N} \sum_{j}^{N_{\mathcal{D}}} \nabla_{\mathbf{\mu}, \mathbf{\sigma}} \log p(y_{j}|\mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}_{i}, \mathbf{x}_{j}) - \nabla_{\mathbf{\eta}} \mathcal{KL}[q(\mathbf{w}|\mathbf{\eta})|p(\mathbf{w})] \end{align} $$

上記の $\nabla_{\mathbf{\mu}, \mathbf{\sigma}} \log p(y_{j}|\mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}_{i}, \mathbf{x}_{j})$ は自動微分で計算できます.

ミニバッチ勾配上昇法

学習に多くの時間がかかる場合に、通常のニューラルネットワークの学習と同様にミニバッチを用いた勾配上昇(降下)法を適用することもできます.学習データ $\mathcal{D}$ から $N_\mathcal{B}$ 個のサンプルをランダムサンプリングしたデータセットを $\mathcal{B}$ とした場合、

$$ \mathcal{L}[q] \approx \frac{N_{\mathcal{D}}}{N_\mathcal{B}} E_{q} \left[\log p(\mathcal{B}|\mathbf{w}) \right] - \mathcal{KL}[q(\mathbf{w}|\mathbf{\eta})|p(\mathbf{w})] $$

となるので、$\textit{i.i.d}$ であることより、

$$ \mathcal{L}[q] \approx \frac{N_{\mathcal{D}}}{N_\mathcal{B}} E_{q} \left[\sum_{j}^{N_\mathcal{B}} \log p(y_{j}|\mathbf{w}, \mathbf{x}_{j}) \right] - \mathcal{KL}[q(\mathbf{w}|\mathbf{\eta})|p(\mathbf{w})] $$

となります.よって、モンテカルロ積分を用いて以下のように近似できます.

$$ \begin{align} \mathbf{\epsilon}_{i} &\sim \mathcal{N}_{d}(\mathbf{\epsilon}|0, I) \quad (i = 1, \ldots, N) \\ \nabla_{\mathbf{\eta}} \mathcal{L}[q] &\approx \frac{N_{\mathcal{D}}}{N_\mathcal{B}}\left[ \frac{1}{N} \sum_{i}^{N} \sum_{j}^{N_\mathcal{B}} \nabla_{\mathbf{\mu}, \mathbf{\sigma}} \log p(y_{j}|\mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}_{i}, \mathbf{x}_{j}) \right] - \nabla_{\mathbf{\eta}} \mathcal{KL}[q(\mathbf{w}|\mathbf{\eta})|p(\mathbf{w})] \end{align} $$

となります.通常、モンテカルロ積分のサンプル数は $N = 1$ で計算されることが多く以下のように簡略化できます.

$$ \begin{align} \mathbf{\epsilon} &\sim \mathcal{N}_{d}(\mathbf{\epsilon}|0, I) \quad (i = 1, \ldots, N) \\ \nabla_{\mathbf{\eta}} \mathcal{L}[q] &\approx \frac{N_\mathcal{D}}{N_\mathcal{B}} \sum_{j}^{N_\mathcal{B}} \nabla_{\mathbf{\mu}, \mathbf{\sigma}} \log p(y_{j}|\mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}, \mathbf{x}_{j}) - \nabla_{\mathbf{\eta}} \mathcal{KL}[q(\mathbf{w}|\mathbf{\eta})|p(\mathbf{w})] \end{align} $$

上記の勾配を自動微分を用いて計算することで、ミニバッチを用いてBNNの学習が行えます.

予測分布

最後に、機械学習においての関心事となる予測分布に関して導出しておきます. 新たに得られたデータを $\mathbf{x}^{*}$ としたとき、予測結果 $y^{*}$ の分布 $p(y^{*}|\mathbf{x}^{*})$ は以下のように表せます.

$$ p(y^{*}|\mathbf{x}^{*}) = \int p(y^{*}|\mathbf{x}^{*}, \mathbf{w}) p(\mathbf{w}| \mathcal{D}) d\mathbf{w} $$

上記の式は、操作としては $\mathbf{w}$ を周辺化を行うことで消去しているだけだが、次のように解釈することができます.パラメータ $\mathbf{w}$ がとりうる値の範囲で個々の予測分布を $p(y^{*}|\mathbf{x}^{*}, \mathbf{w})$ として、データから推論したそのパラメータである可能性 $p(\mathbf{w}|\mathcal{D})$ で重みづけたものと解釈できます.予測分布 $p(y^{*}|\mathbf{x}^{*})$ もまた確率分布であり、上記のように事後分布の不確実性を考慮に入れた予測ができるというメリットがあります. 実際にBNNに適用する際には、上記の積分を解くことも難しいのでサンプリングを用いて近似することが多いです.事後分布 $p(\mathbf{w}| \mathcal{D})$ の代わりに最適化された近似分布 $q^{*}(\mathbf{w}|\mathbf{\eta})$ を用いて上記の予測分布を近似すると、

$$ p(y^{*}|\mathbf{x}^{*}) \approx \int p(y^{*}|\mathbf{x}^{*}, \mathbf{w}) q^{*}(\mathbf{w}|\mathbf{\eta}) d\mathbf{w} $$

となります.このとき、$p(y^{*}|\mathbf{x}^{*})$ のサンプリングは以下のように行えます.

$$ \begin{align} \mathbf{w}_{i} &\sim q^{*}(\mathbf{w}|\mathbf{\eta}) \\ y^{*}_{i} &\sim p(y^{*}|\mathbf{x}^{*}, \mathbf{w}_{i}) \end{align} $$

上記で得られたサンプルで興味のある量(平均値、信用区間)を計算することができます.

参考文献

  1. Kevin P. Murphy, Probabilistic Machine Learning: An Introduction
  2. Kevin P. Murphy, Probabilistic Machine Learning: Advanced Topics
  3. C.M.Bishop, Pattern Recogition and Machine Learning
  4. S.Mohamed et al., Monte Carlo Gradient Estimation in Machine Learning
  5. C.Blundell et al., Weight Uncertainty in Neural Networks
  6. Yarin Gal, Uncertainty in Deep Learning
  7. 須山敦志, ベイズ深層学習
  8. 須山敦志, ベイズ推論による機械学習入門
  9. 柳井啓司 他, 深層学習

まとめ

ベイズニューラルネットワークの確率モデルとその推論方法(変分推論)に関して紹介しました.実装やアルゴリズムについては紹介しませんでしたが、ディープラーニングフレームワークの自動微分を活かした形で実装することが可能なライブラリがあるので、そちらを用いると実装も簡単です.Pytorch の場合は Pyro、Tensorflow の場合は Tensorflow Probability がありますので、実装の際はそちらのドキュメントを参考にしてください.

Opt Technologies ではエンジニアを募集中です。カジュアル面談も可能ですので、下記リンク先よりお気軽にご応募ください。