一日坊主

雰囲気でやっている

PRML上巻 P6

昨日の多項式フィッティングのコードは流石に汚すぎたので,ライブラリを使うことにする.

numpy に numpy.polyfit というそのものズバリな機能があるようだ.

ドキュメントによると v1.4 以降は非推奨らしいが,簡単なので,これを使って再度図(1.4)を作成してみる.

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
# t = sin(2 \pi x)
x1 = np.linspace(0, 1, 100)
t1 = np.sin(2 * np.pi * x1)
# t = sin(2 \pi x) + gaussian noise
x2 = np.linspace(0, 1, 10)
t2 = np.sin(2 * np.pi * x2) + np.random.normal(0, 0.2, 10)
# figure 1.4

# fitting with polynomial functions
w0 = np.polyfit(x2, t2, 0)
w1 = np.polyfit(x2, t2, 1)
w3 = np.polyfit(x2, t2, 3)
w9 = np.polyfit(x2, t2, 9)

# visualize
fig, axs = plt.subplots(2, 2, figsize=(12, 8))
axs[0, 0].set_title('M = 0')
axs[0, 1].set_title('M = 1')
axs[1, 0].set_title('M = 3')
axs[1, 1].set_title('M = 9')

for ax in axs.flat:
    ax.plot(x1, t1, color='green')
    ax.scatter(x2, t2, facecolor='None', edgecolor='blue')
    ax.set(xlabel='x', ylabel='t')

axs[0, 0].plot(x1, np.poly1d(w0)(x1), color='red')
axs[0, 1].plot(x1, np.poly1d(w1)(x1), color='red')
axs[1, 0].plot(x1, np.poly1d(w3)(x1), color='red')
axs[1, 1].plot(x1, np.poly1d(w9)(x1), color='red')

plt.tight_layout()

f:id:twakamori:20210219012715p:plain
figure 1.4

それっぽく再現できた. たったの10サンプルなら正規方程式を使って解析的に解いた方が早いかもしれないが,そろそろ次に進みたいのでこれは一旦これで良しとする.

PRML上巻 P5-6

昨日の続きをやっていく.

\displaystyle{
y(x, \mathbf{w}) = w_0 + w_1 x + w_2 x^2 + \cdots + w_M x^M = \sum_{j=0}^{M}w_j x^j \tag{1.1}
}

(1.1)多項式の字数Mを選ぶ問題は,モデル比較(model comparioson)あるいはモデル選択(model selection)と呼ばれる.

例として,次数M=0,1,3,9多項式を当てはめてみる. 以下,コード(めちゃくちゃ汚いので後日修正する).

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
# t = sin(2 \pi x)
x1 = np.linspace(0, 1, 100)
t1 = np.sin(2 * np.pi * x1)
# t = sin(2 \pi x) + gaussian noise
x2 = np.linspace(0, 1, 10)
t2 = np.sin(2 * np.pi * x2) + np.random.normal(0, 0.2, 10)
# figure 1.4

def gradientDescentM0(X, t, w, alpha, num_iters):
    """Performs gradient descent to learn `w`."""
    n = t.shape[0]
    w = w.copy()
    def y(X, w):
        return w[0] * X[:, 0]
    for i in range(num_iters):
        w0 = w[0] - alpha * (1 / n) * np.sum((y(X, w) - t) * X[:, 0])
        w = [w0]
    return w

def gradientDescentM1(X, t, w, alpha, num_iters):
    """Performs gradient descent to learn `w`."""
    n = t.shape[0]
    w = w.copy()
    def y(X, w):
        return w[0] * X[:, 0] + w[1] * X[:, 1]
    for i in range(num_iters):
        w0 = w[0] - alpha * (1 / n) * np.sum((y(X, w) - t) * X[:, 0])
        w1 = w[1] - alpha * (1 / n) * np.sum((y(X, w) - t) * X[:, 1])
        w = [w0, w1]
    return w

def gradientDescentM3(X, t, w, alpha, num_iters):
    """Performs gradient descent to learn `w`."""
    n = t.shape[0]
    w = w.copy()
    def y(X, w):
        return w[0] * X[:, 0] + w[1] * X[:, 1] + w[2] * np.power(X[:, 1], 2) + w[3] * np.power(X[:, 1], 3)
    for i in range(num_iters):
        w0 = w[0] - alpha * (1 / n) * np.sum((y(X, w) - t) * X[:, 0])
        w1 = w[1] - alpha * (1 / n) * np.sum((y(X, w) - t) * X[:, 1])
        w2 = w[2] - alpha * (1 / n) * np.sum((y(X, w) - t) * np.power(X[:, 1], 2))
        w3 = w[3] - alpha * (1 / n) * np.sum((y(X, w) - t) * np.power(X[:, 1], 3))
        w = [w0, w1, w2, w3]
    return w

def gradientDescentM9(X, t, w, alpha, num_iters):
    """Performs gradient descent to learn `w`."""
    n = t.shape[0]
    w = w.copy()
    def y(X, w):
        return w[0] * X[:, 0] + w[1] * X[:, 1] + w[2] * np.power(X[:, 1], 2) + w[3] * np.power(X[:, 1], 3) \
            + w[4] * np.power(X[:, 1], 4) + w[5] * np.power(X[:, 1], 5) + w[6] * np.power(X[:, 1], 6) \
            + w[7] * np.power(X[:, 1], 7) + w[8] * np.power(X[:, 1], 8) + w[9] * np.power(X[:, 1], 9)
    for i in range(num_iters):
        w0 = w[0] - alpha * (1 / n) * np.sum((y(X, w) - t) * X[:, 0])
        w1 = w[1] - alpha * (1 / n) * np.sum((y(X, w) - t) * X[:, 1])
        w2 = w[2] - alpha * (1 / n) * np.sum((y(X, w) - t) * np.power(X[:, 1], 2))
        w3 = w[3] - alpha * (1 / n) * np.sum((y(X, w) - t) * np.power(X[:, 1], 3))
        w4 = w[4] - alpha * (1 / n) * np.sum((y(X, w) - t) * np.power(X[:, 1], 4))
        w5 = w[5] - alpha * (1 / n) * np.sum((y(X, w) - t) * np.power(X[:, 1], 5))
        w6 = w[6] - alpha * (1 / n) * np.sum((y(X, w) - t) * np.power(X[:, 1], 6))
        w7 = w[7] - alpha * (1 / n) * np.sum((y(X, w) - t) * np.power(X[:, 1], 7))
        w8 = w[8] - alpha * (1 / n) * np.sum((y(X, w) - t) * np.power(X[:, 1], 8))
        w9 = w[9] - alpha * (1 / n) * np.sum((y(X, w) - t) * np.power(X[:, 1], 9))
        w = [w0, w1, w2, w3, w4, w5, w6, w7, w8, w9]
    return w

# add w_0 and stack examples
X = np.stack([np.ones(x2.size), x2], axis=1)
y = t2
iterations = 100000
alpha = 0.1

w0 = gradientDescentM0(X, y, np.zeros(1), alpha, iterations)
w1 = gradientDescentM1(X, y, np.zeros(2), alpha, iterations)
w3 = gradientDescentM3(X, y, np.zeros(4), alpha, iterations)
w9 = gradientDescentM9(X, y, np.zeros(10), alpha, iterations)

fig, axs = plt.subplots(2, 2, figsize=(12, 8))
axs[0, 0].set_title('M = 0')
axs[0, 1].set_title('M = 1')
axs[1, 0].set_title('M = 3')
axs[1, 1].set_title('M = 9')

for ax in axs.flat:
    ax.plot(x1, t1, color='green')
    ax.scatter(x2, t2, facecolor='None', edgecolor='blue')
    ax.set(xlabel='x', ylabel='t')

def y3(X, w):
    return w[0] * X[:, 0] + w[1] * X[:, 1] + w[2] * np.power(X[:, 1], 2) + w[3] * np.power(X[:, 1], 3)
def y9(X, w):
    return w[0] * X[:, 0] + w[1] * X[:, 1] + w[2] * np.power(X[:, 1], 2) + w[3] * np.power(X[:, 1], 3) \
        + w[4] * np.power(X[:, 1], 4) + w[5] * np.power(X[:, 1], 5) + w[6] * np.power(X[:, 1], 6) \
        + w[7] * np.power(X[:, 1], 7) + w[8] * np.power(X[:, 1], 8) + w[9] * np.power(X[:, 1], 9)

# M = 0
axs[0, 0].plot(x1, w0[0] * np.ones(x1.size), color='red')
# M = 1
axs[0, 1].plot(x1, w1[0] * np.ones(x1.size) + w1[1] * x1, color='red')
# M = 3
axs[1, 0].plot(x1, y3(np.stack([np.ones(x1.size), x1], axis=1), w3), color='red')
# M = 9
axs[1, 1].plot(x1, y9(np.stack([np.ones(x1.size), x1], axis=1), w9), color='red')

plt.tight_layout()

f:id:twakamori:20210218031821p:plain
figure 1.4

M = 3の図とM = 9の図が書籍と異なり収束していないことがわかる. 最急降下法だと収束に時間がかかるので,別の方法を取ったほうが良さそうだ.

今日はここまで.

PRML上巻 P5

昨日の続きをやっていく.

\displaystyle{
y(x, \mathbf{w}) = w_0 + w_1 x + w_2 x^2 + \cdots + w_M x^M = \sum_{j=0}^{M}w_j x^j \tag{1.1}
}

この関数と

\displaystyle{
E(\mathbf{w})=\frac{1}{2}\sum_{n=1}^{N}\left\{y(x_n, \mathbf{w}) - t_n \right\}^2\tag{1.2}
}

この誤差関数について考える.

演習問題 1.1

関数y(x, \mathbf{w})多項式(1.1)で与えられたときの(1.2)の二乗和誤差関数を考える.この誤差関数を最小にする係数\mathbf{w}=\{w_i\}は以下の線形方程式の解として与えられることを示せ.

\displaystyle{
\sum_{j=0}^{M}A_{ij}w_j=T_i\tag{1.122}
}

ただし,

\displaystyle{
A_{ij}=\sum_{n=1}^{N}(x_n)^{i+j}, T_i=\sum_{n=1}^{N}(x_n)^{i}t_n \tag{1.123}
}

ここで,下付き添字のijは成分を表し,(x)^ ixi乗を表す.

演習問題 1.1 解答

(1.2)に式(1.1)を代入する.

\displaystyle{
E(\mathbf{w})=\frac{1}{2}\sum_{n=1}^{N}\left\{\sum_{j=0}^{M}w_j(x_n)^j - t_n\right\}^2
}

(1.2)が最小のとき,上式のw_iについての偏微分が0となる.


\begin{aligned}
\frac{\partial}{\partial w_i}E(\mathbf{w})&=2\cdot\frac{1}{2}\sum_{n=1}^{N}\left\{\sum_{j=0}^{M}w_j(x_n)^j-t_n\right\}(x_n)^i \\
&=\sum_{j=0}^{M}\sum_{n=1}^{N}(x_n)^{(i+j)}w_j-\sum_{n=1}^{N}(x_n)^it_n \\
&=\sum_{j=0}^{M}A_{ij}w_j-T_i \\
&=0
\end{aligned}

上式のT_iを移項すると,式(1.122)が得られる.

\displaystyle{
\sum_{j=0}^{M}A_{ij}w_j=T_i
}

本日は以上.(はてなで数式をいじっていたら終わった)

PRML上巻 P1-5

とりあえず, PRML を頭から読み直すやつをやることにした.

今日は「1.1 多項式曲線フィッティング」をやってみる.

訓練集合として,N個の観測値xを並べた\mathbf{x}\equiv(x_1,\ldots,x_N)^ \mathrm{T}と, それぞれに対応する観測値tを並べた\mathbf{t}\equiv(t_1,\ldots,t_N)^ \mathrm{T}が与えられる.

我々の目標は,この訓練集合を利用して,新たな入力変数の値\hat{x}に対して目標変数\hat{t}の値を予測することである.

まず,関数  sin(2 \pi x) にガウシアンノイズを加えて,N = 10 個の訓練データを生成する.

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
# t = sin(2 \pi x)
x1 = np.linspace(0, 1, 100)
t1 = np.sin(2 * np.pi * x1)
plt.plot(x1, t1, color='green')

# training data set
# t = sin(2 \pi x) + gaussian noise
x2 = np.linspace(0, 1, 10)
t2 = np.sin(2 * np.pi * x2) + np.random.normal(0, 0.2, 10)
plt.scatter(x2, t2, facecolor='None', edgecolor='blue')

plt.xlabel('x')
plt.ylabel('t')

f:id:twakamori:20210216020614p:plain
figure 1.2

この訓練データに対して,以下のような多項式をフィッティングする.

\displaystyle{
y(x, \mathbf{w}) = w_0 + w_1 x + w_2 x^2 + \cdots + w_M x^M = \sum_{j=0}^{M}w_j x^j \tag{1.1}
}

この多項式は線形モデル(linear model)と呼ばれる.

訓練データに多項式をあてはめることで係数の値を求める. これは,誤差関数(error function)の最小化で達成できる. 誤差関数には以下の二乗和誤差(sum-of-squares error)を用いる.

\displaystyle{
E(\mathbf{w})=\frac{1}{2}\sum_{n=1}^{N}\left\{y(x_n, \mathbf{w}) - t_n \right\}^2\tag{1.2}
}

E(\mathbf{w})は非負であり,y(x, \mathbf{w})が全訓練データ点をちょうど通るとき,0になる. このように,E(\mathbf{w})をできるだけ小さくするような\mathbf{w}を選ぶことで曲線あてはめ問題を解くことができる.

今日はここまで. (100年くらいかかりそう)

1日15分振り返りの時間を作ると良いらしい

元ネタは以下の記事.

www.forbes.com

これによると,

Research has shown that when employees spent just 15 minutes per day reflecting on what they learned that day, they began to perform 23% better after just 10 days.

とのことらしい.

そこで,明日から学習内容を振り返ってみることにする. なるべく継続できるよう,内容とボリュームに制約は設けないことにする. さて,いつまで続けられるか...

Task-Relevant Adversarial Imitation Learning を読んだ

この記事は強化学習苦手の会 Advent Calendar 2020の22日目の記事です.

本記事では,Task-Relevant Adversarial Imitation LearningというDeepMindの論文を紹介します. この論文はCoRL2020のSpotlight Talkに選ばれたようです. 本記事の構成は落合先生の論文まとめフォーマットに従っています.

なお私は強化学習が本当に苦手なので,論文をちゃんと理解できている自信がありません.有識者の皆様からのご指摘を心待ちにしています.

どんなもの?

敵対的模倣学習(GAIL)のパフォーマンスを低下させる原因として,GAILのDiscriminator(識別器)がspurious associations(本記事ではこれを疑似関連と呼ぶことにします)に引っ張られてしまうことを指摘し,それを解決する手法であるTask-Relevant Adversarial Imitation Learning (TRAIL)を提案しています.


模倣学習とは,エキスパートの方策を模倣する方策を学習する問題です.模倣学習手法の一つであるGAILは,敵対的生成ネットワーク(GANs)の枠組みで,エキスパートの方策の生成する状態行動対と,Generator(生成器)の生成する状態行動対を区別する識別器を敵対的に学習します.こうすることで,報酬関数を陽に定義することなく,エキスパートの状態行動対の集合(軌跡)からターゲットの方策を獲得できます.

疑似関連とは何かというと,「エキスパートのラベル」と「タスク無関係の特徴」との間に生じる疑似相関のようなものです.イメージ図が論文中のFigure 1にありますので引用します.

f:id:twakamori:20201221051518p:plain
疑似関連のイメージ (元論文のFigure 1より引用)

図の縦軸はタスクに関係する特徴,横軸はspuriousな特徴(疑似特徴)を表しています.緑色の点はエキスパートの観測,赤色は学習中のエージェントの観測を表します. 学習が進むにつれてエージェントはエキスパートに近づこうとするため,生成器はタスクに関連する特徴を出力するようになります(学習の経過は赤色の濃さで表されています). その結果,識別器の決定境界(図中の青色の破線)は,タスクに関連する特徴の代わりに,擬似特徴に依存してエキスパートとエージェントを区別するように学習されてしまいます. このような識別器がタスク無関係の特徴に依存してしまう問題を,この論文では疑似関連と呼んでいます.

論文では疑似関連の例として,画像を状態とする制御タスクにおける初期状態や画像の輝度を挙げています.エージェントはそれらを制御できないため,いかなる方策を生成しても識別器を騙すことができず,その結果,生成器の学習が進まなくなります.

先行研究と比べてどこがすごい?

Stadieらの先行研究*1は,画像において,エキスパートとエージェントの観測が異なる視点から来るような三人称の問題にGAILを適用しました.先行研究は識別器が視点を利用しないようにするために,gradient flipping*2というテクニックを用いてドメイン不変の特徴を学習します.しかし,本論文の著者らによると,Stadieらの先行研究は疑似関連の問題に対してドメイン不変の特徴をうまく抽出できなかったそうです.

また,疑似関連の問題は,識別器の過学習の問題と似ています.先行研究*3*4*5は,GAILの識別器の過学習を防ぐために正則化を行いますが,前述のような画像特徴の疑似関連をもつ場合,正則化ではそれを防ぐことができず,場合によっては疑似関連を強調してしまうことさえあります.

技術や手法のキモはどこ?

提案手法(TRAIL)は,constraining sets(制約集合) \mathcal{I} _ E \mathcal{I} _ Aを導入します. \mathcal{I} _ E \mathcal{I} _ Aは,それぞれエキスパートの観測の集合とエージェントの観測の集合を表します. \mathcal{I} _ E \mathcal{I} _ Aを識別する特徴を疑似特徴として区別することで,識別器がタスクに関連する特徴をうまく捉えることができるようになります.


f:id:twakamori:20201224033707p:plain
 \mathcal{I} _ E \mathcal{I} _ Aの導入イメージ (元論文のFigure 3より引用)

上図の左側は,冒頭の図で学習された決定境界がタスクに関連する境界と擬似的な境界の組み合わせであることを示しています.上図の右側のように, \mathcal{I} _ E \mathcal{I} _ Aを導入し,これらを識別する擬似的な境界を学習し,識別器がタスクに関連する境界をうまく捉えられるようにすることが本研究の課題となります.

手順は以下のとおりです.まず,GAILの識別器の目的関数を変形して,状態行動対 (s, a)の代わりに状態 (s)のみを扱うようにします(状態のみを扱う模倣学習は先行研究*6があります).

\displaystyle{
\underset{\psi}{\max}\mathbb{E}_{s\sim\pi_E}\left[\log D_\psi(s)\right]+\mathbb{E}_{s\sim\pi_A}\left[\log \left(1-D_\psi(s)\right)\right] \tag{1}
}

ここで, \pi _ A \pi _ Eはそれぞれエージェントの方策とエキスパートの方策を表し, Dは識別器のネットワークを表します.学習した識別器から報酬関数が得られます(本研究ではシンプルに R(s) = -\log(1-D _ \psi(s))と置いています). \pi _ Aはこの報酬関数を最大化するよう学習されます.

TRAILの識別器は,式 (1)と同様に交差エントロピーを用いますが, \mathcal{I} _ E \mathcal{I} _ Aからの観測に対して制約をかけます.具体的には, \mathcal{I} _ E \mathcal{I} _ Aからの観測に対して, accuracy(\mathcal{I} _ E, \mathcal{I} _ A)\geq\frac{1}{2}の場合は負の制約をかけます.ここで, accuracy(\mathcal{I} _ E, \mathcal{I} _ A)は,2つの制約集合からの観測の均等な集合における平均識別器精度として次式で定義されます.

\displaystyle{
accuracy(\mathcal{I}_E, \mathcal{I}_A)=\frac{1}{2}\mathbb{E}_{s\in\mathcal{I}_E}\left[\mathbf{1}_{D_\psi(s)\geq\frac{1}{2}}\right]+\frac{1}{2}\mathbb{E}_{s\in\mathcal{I}_A}\left[\mathbf{1}_{D_\psi(s)\lt\frac{1}{2}}\right] \tag{2}
}

直感的には,この制約は「識別器が疑似特徴を識別・利用することを抑制し,もし識別器の学習に使われた場合はそれを忘却させる」という働きをします.

エキスパートとエージェントからの Nサンプルのバッチをそれぞれ s _ E \sim \pi _ E, s _ A \sim \pi _ A,制約観測のサンプルを  \hat{s} _ E \subset \mathcal{I} _ E, \hat{s} _  A \subset \mathcal{I} _ Aとすると,TRAILの識別器は以下を最大化します.

\displaystyle{
\mathcal{L}_\psi(s_E,s_A,\hat{s}_E,\hat{s}_A)=G_\psi(s_E,s_A)-\mathbf{1}_{accuracy(\hat{s}_E,\hat{s}_A)\geq\frac{1}{2}}G_\psi(\hat{s}_E,\hat{s}_A)
}

ここで, G _ \psi(s _ E,s _ A) D _ \psiに対するGAILの識別器の損失のサンプルによる推定値を表し,次式で定義されます.

\displaystyle{
G_\psi(s_E,s_A)=\sum_{i=1}^{N}\log D_\psi\left(s_E^{(i)}\right)+\log \left[1-D_\psi\left(s_A^{(i)}\right)\right] \tag{3}
}

また, \mathbf{1} _ {accuracy(\hat{s} _ E,\hat{s} _ A)\geq\frac{1}{2}}は指示関数であり,制約を違反するかどうかを表します.

以上が手法の全体像となります. さて,ここで大きな疑問として「 \mathcal{I} _ E \mathcal{I} _ Aをどうやって用意するのか」という点が挙げられます. これについて,論文では以下のように述べられていました.

  • エキスパートとエージェントの観測に非定常な分布シフトがある場合:不可能
  • そうでない場合:
    • エキスパートとエージェントの両方の設定でランダムな方策を実行する…ランダム方策はタスク依存の情報を持たないため有効
    • エキスパートとエージェントのエピソードの初期のフレームを使用する...初期のフレームはタスク依存の行動がほとんど見られないため有効

どうやって有効だと検証した?

シミュレータ上のロボットマニピュレーションタスク(積み上げ,持ち上げ等)で他手法(GAIL,GAIL+AES,D4PGfD)と性能を比較し,エキスパートとエージェントの見た目が異なるタスクにおいて,TRAILが最も良い性能であることを確認しています.


実験設定の詳細については説明を割愛します.本研究で比較対象としているアルゴリズムは以下です.

  • BC: Behavior Cloning
  • GAIL: ベースライン(式 (1)の目的関数で識別器を訓練)
  • TRAIL: 提案手法(ベースラインの識別器の目的関数を式 (3)に変えたもの+early stopping)
  • GAIL+AES: ベースライン+early stopping
  • D4PGfD: D4PGによるエージェントの学習において,経験再生にエキスパートの軌跡を加えたもの(報酬関数には別途定義したスパースな報酬関数を用いる)

1つ目の実験は,Lift(持ち上げ),Stack(積み上げ),Stack banana(バナナ形状のオブジェクトの積み上げ),Insertion(挿入)の4つのマニピュレーションタスクにおいて,各アルゴリズムを比較しています.この実験では,エキスパートとエージェントの観測の間で初期状態や輝度などを特にいじっていないので,疑似特徴がどのような特徴を捉えるかは分かりません.

f:id:twakamori:20201224220458p:plain
マニピュレーションタスクにおける性能比較結果 (元論文のFigure 5より引用)

結果は上図のとおりで,TRAILは唯一すべてのタスクでエキスパートの性能に到達しています.(1つのアルゴリズムに対して複数の線が描画されていて分かりづらいですね…)

2つ目の実験は,Liftタスクにおいて視覚的な疑似特徴を挿入した場合の性能を評価しています.具体的には,グリップの色について,エキスパートは明るい色,エージェントは暗い色にそれぞれ設定し,両者に変化をつけています.

f:id:twakamori:20201224221859p:plain
視覚的な疑似特徴がある場合の性能比較 (元論文のFigure 6より引用)

結果は上図のとおりです.こちらは,GAIL+AESとTRAILの2つで比較していて,TRAILの性能がGAIL+AESをはっきりと上回り,エキスパートの性能に達していることが分かります.両者のアルゴリズムの差分は制約集合の有無だけですので,制約集合の導入が性能向上に寄与したということが分かります.

3つ目の実験は,2つ目の実験と同じようにエキスパートとエージェントの観測に変化を付けますが,今度はグリップの色ではなく,背景にタスクと関係ないブロックを配置します.ここでは,Reedら*7によって提案された,ランダムに初期化された畳み込みネットワークあるいは畳み込みcriticネットワークを使用して視覚的な特徴を捉え,その上に小さな識別器を訓練する手法と比較します(それぞれ,GAIL+random, GAIL+criticと呼びます).

f:id:twakamori:20201224224115p:plain
背景にタスクと関係ないブロックがある場合の性能比較 (元論文のFigure 7より引用)

結果は上図のとおり,通常のLiftタスクは全ての手法でそこそこ良い性能が出ています(とはいえ,GAIL+AESとTRAILの性能が最も良い)が,背景にタスクと関係ないブロックがある場合(Lift distracted)は,TRAILのみがエキスパートの性能を達成しています.このように,明らかにタスクと関係ない(グリップの色のような)特徴でなくても,TRAILは疑似特徴を見つけ出し,区別することができているようです.

3つ目の実験では,制約集合を構築するために初期フレームを使用していました.4つ目の実験は,制約集合を構築するもう一つの戦略である,ランダム方策を用いる戦略について評価します.制約集合を構築する2つの戦略の違いを明らかにするために,2つ目の実験と3つ目の実験を組み合わせた難しいタスク(グリップの色が異なる+背景にタスクと関係ないブロックがある)を用います.

f:id:twakamori:20201224225240p:plain
制約集合を構築する2つの戦略の性能比較 (元論文のFigure 8より引用)

結果は上図のとおり,このような困難なタスクではTRAILも性能向上に時間がかかっています.また2つの戦略の性能はよく似ています.このため,著者らは追加データ収集の必要のない,初期のフレームを使用する方法を推奨しています.

議論はある?

特にないようですが,early stoppingや汎化性能,最適なDiscriminatorを用いた場合との比較等について,ablation studiesを実施しています. また,supplementary materialに実験設定やデータ拡張などのチューニング,ネットワーク構造やハイパーパラメータが詳細に記載されています.

次に読むべき論文は?


以上,Task-Relevant Adversarial Imitation LearningというDeepMindの論文を紹介いたしました.

個人的な所感としては,タスクが画像に限定されているものの,問題提起やアプローチがシンプルで分かりやすく,実装しやすそうな点が良いと思いました(本当は実装しようと思ったんですが,実装力と時間が足りなくて間に合いませんでした). ただ,GAILの改善系アルゴリズムは複数提案されているにもかかわらず,素のGAILとしか比較していないところがちょっと気になりました(実際,そのような理由でICLR2020をRejectされているようでした).

PyMCによるベイズ推論

本記事では以下のレポジトリの Ch1_Introduction_PyMC3.ipynb の内容を一部抜粋して実行してみる.

github.com

まず,データをダウンロードして表示する.

import numpy as np
import matplotlib.pyplot as plt
import os

# 'data' ディレクトリを作成する
makedirs('data', exist_ok=True)

# サンプルデータ(ユーザが受信するメッセージ数)をダウンロードして 'data/txtdata.csv' に配置する
from urllib.request import urlretrieve
urlretrieve('https://git.io/vXTVC', 'data/txtdata.csv')

# 可視化
figsize(12.5, 3.5)
count_data = np.loadtxt("data/txtdata.csv")
n_count_data = len(count_data)
plt.bar(np.arange(n_count_data), count_data)
plt.xlabel("Time (days)")
plt.ylabel("count of text-msgs received")
plt.title("Did the user's texting habits change over time?")
plt.xlim(0, n_count_data)

f:id:twakamori:20200703000910p:plain
サンプルデータ(ユーザが受信するメッセージ数)

このような計数データはポアソン分布を使ってモデリングする. i 日目のメッセージ数をC_iとすると,


C_i \sim \mathrm{Poi}(\lambda)

ここで,パラメータ\lambdaをどうやって決めるか. グラフ後半のほうで数値が大きくなるようにみえる. そこで,ある日\tauを境に,パラメータ\lambdaが突然大きくなると仮定する.

 \lambda = \left\{
\begin{array}{ll}
\lambda_1 & (t \lt \tau ) \\
\lambda_2 & (t \geq \tau ) \\
\end{array}
\right.

この2つの未知数\lambda_1,\lambda_2について,ベイズ推論による推定を試みる. ベイズ推論を使うには,2つの\lambdaに対して事前分布を決める必要がある. 正の実数のための確率分布として,指数分布を使うのがちょうどよい.

指数分布のパラメータを\alphaとおくと,

 \begin{align}
\lambda_1 &\sim \mathrm{Exp}(\alpha) \\
\lambda_2 &\sim \mathrm{Exp}(\alpha)
\end{align}

とかける.

\tauについては,変化点が何日目かを判断するのは難しい.そこで,ここでは一様分布を使う.

 \begin{align}
\tau &\sim \mathrm{DiscreteUniform}(1, 70) \\
& \Rightarrow P(\tau = k) = \frac{1}{70}
\end{align}

以上の確率変数 (\tau, \lambda_1, \lambda_2) について,PyMCを使って推定する.

まず,\lambda_1, \lambda_2に対応するPyMC変数を作成する.

import pymc3 as pm
import theano.tensor as tt

with pm.Model() as model:
    alpha = 1.0/count_data.mean()  # count_data: メッセージ受信数

    lambda_1 = pm.Exponential("lambda_1", alpha)
    lambda_2 = pm.Exponential("lambda_2", alpha)
    
    tau = pm.DiscreteUniform("tau", lower=0, upper=n_count_data - 1)

次に,関数 lambda_ を定義する.実際にはこれを確率変数 \lambda とみなすことができる.

with model:
    idx = np.arange(n_count_data) # Index
    lambda_ = pm.math.switch(tau > idx, lambda_1, lambda_2)

次に, count_datalambda_ を受け取り,この例題のデータ生成モデルであるポアソン分布のオブジェクトを生成し,変数 observation に代入する.

with model:
    observation = pm.Poisson("obs", lambda_, observed=count_data)

以上のコードに対して,マルコフ連鎖モンテカルロ法(MCMC)を使った学習(事後分布からのサンプリング)を行う.

with model:
    step = pm.Metropolis()
    trace = pm.sample(10000, tune=5000,step=step)

出力は以下のようになる.

Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>Metropolis: [tau]
>Metropolis: [lambda_2]
>Metropolis: [lambda_1]
Sampling 4 chains, 0 divergences: 100%|██████████| 60000/60000 [00:08<00:00, 7404.58draws/s]
The number of effective samples is smaller than 25% for some parameters.

サンプル値のヒストグラムをプロットして,事後分布の形状を確かめる.

lambda_1_samples = trace['lambda_1']
lambda_2_samples = trace['lambda_2']
tau_samples = trace['tau']

figsize(12.5, 10)
#histogram of the samples:

ax = plt.subplot(311)
ax.set_autoscaley_on(False)

plt.hist(lambda_1_samples, histtype='stepfilled', bins=30, alpha=0.85,
         label="posterior of $\lambda_1$", color="#A60628", density=True)
plt.legend(loc="upper left")
plt.title(r"""Posterior distributions of the variables
    $\lambda_1,\;\lambda_2,\;\tau$""")
plt.xlim([15, 30])
plt.xlabel("$\lambda_1$ value")

ax = plt.subplot(312)
ax.set_autoscaley_on(False)
plt.hist(lambda_2_samples, histtype='stepfilled', bins=30, alpha=0.85,
         label="posterior of $\lambda_2$", color="#7A68A6", density=True)
plt.legend(loc="upper left")
plt.xlim([15, 30])
plt.xlabel("$\lambda_2$ value")

plt.subplot(313)
w = 1.0 / tau_samples.shape[0] * np.ones_like(tau_samples)
plt.hist(tau_samples, bins=n_count_data, alpha=1,
         label=r"posterior of $\tau$",
         color="#467821", weights=w, rwidth=2.)
plt.xticks(np.arange(n_count_data))

plt.legend(loc="upper left")
plt.ylim([0, .75])
plt.xlim([35, len(count_data)-20])
plt.xlabel(r"$\tau$ (in days)")
plt.ylabel("probability")

f:id:twakamori:20200703234646p:plain
事後分布

図から分かることは以下.

  • 推定値の不確実さ
    • 分布の幅が広ければ,事後信念はあまり確信できるものではない
  • パラメータの妥当な値
    • \lambda_1はおよそ18,\lambda_2はおよそ23
  • 2つの\lambdaの事後分布が明らかに異なっている
    • ユーザのメッセージ受信に変化があった可能性が高い
  • \tauの分布
    • 45日目にユーザが振る舞いを変えた確率が50%程度

最後に,サンプルデータにメッセージ数の期待値を重ねて表示する.

figsize(12.5, 5)

N = tau_samples.shape[0]
expected_texts_per_day = np.zeros(n_count_data)
for day in range(0, n_count_data):
    ix = day < tau_samples
    expected_texts_per_day[day] = (lambda_1_samples[ix].sum()
                                   + lambda_2_samples[~ix].sum()) / N


plt.plot(range(n_count_data), expected_texts_per_day, lw=4, color="#E24A33",
         label="expected number of text-messages received")
plt.xlim(0, n_count_data)
plt.xlabel("Day")
plt.ylabel("Expected # text-messages")
plt.title("Expected number of text-messages received")
plt.ylim(0, 60)
plt.bar(np.arange(len(count_data)), count_data, color="#348ABD", alpha=0.65,
        label="observed texts per day")

plt.legend(loc="upper left")

f:id:twakamori:20200703235929p:plain
受信メッセージ数とその期待値