Variational Intrinsic Successor Features (VISR)
Contents
Variational Intrinsic Successor Features (VISR)#
제목: Fast Task Inference with Variational Intrinsic Successor Features
저자: Hansen, Steven, Will Dabney, Andre Barreto, Tom Van de Wiele, David Warde-Farley, and Volodymyr Mnih
연도: 2020년
학술대회: ICLR
키워드: Unsupervised RL, Skill discovery, Successor features
Summary: Successor features를 이용한 unsupervised skill discovery 알고리즘
Successor Feature (SF) and Universal Successor Feature (USF)#
Successor feature (SF)를 다루는 연구에서는 보상함수 \(r(s, a)\)가 다음과 같이 표현될 수 있다고 가정한다.
현재 상태 \(s\)와 행동 \(a\) 그리고 다음 상태에서 어떤 정보를 뽑아서 특징 벡터 \(\phi(s, a, s')\)를 만들고, 이 특징 벡터의 원소들에 가중치를 줘서 더한 것으로 보상함수가 정의된다는 것이다. 예를 들어, 보행하는 로봇 에이전트를 학습하고 싶다고 할 때, 관절들의 위치와 속도 그리고 관절에 가해진 토크 등에 적절하게 가중치를 부여하여 보상함수를 정의하는 것을 떠올린다면 위 가정이 그렇게 나쁜 가정은 아니다.
그리고 특징 벡터가 고정되었을 때, 가중치 벡터 \(\mathbf{w}\)를 어떻게 주느냐에 따라서 보상함수가 달라지게 된다. 따라서 벡터 \(\mathbf{w}\)를 task vector라고 부른다. 보행 로봇 에이전트 예시에서 보행 방향에 양의 가중치를 주면 앞으로 가는 보행 로봇이 학습될 것이고, 음의 가중치를 주면 뒤로 가는 보행 로봇이 학습될 것이다. 앞으로 task vector \(\mathbf{w}\)에 의해 결정된 보상함수를 \(r_{\mathbf{w}}\)라고 표기할 것이다.
주어진 정책 \(\pi\)의 보상함수 \(r_{\mathbf{w}}\)에 대한행동가치함수 \(Q_{\mathbf{w}}^\pi(s, a)\)의 정의에 식 (1)을 대입하면 다음과 같이 정리할 수 있다.
이제 정책 \(\pi\)의 SF \(\psi: \mathcal{S} \times \mathcal{A} \rightarrow \mathbb{R}^d\)를 다음과 같이 정의해보자.
그러면 정책 \(\pi\)의 행동가치함수는 다음과 같이 간단히 표현될 수 있다.
따라서 정책 \(\pi\)의 SF \(\psi^{\pi}(s, a)\)만 알고 있다면, task vector \(\mathbf{w}\)만 바꿔줌으로써 보상함수 \(r_{\mathbf{w}}\)에 대한 행동가치함수를 계산할 수 있게 된다.
Universal Successor Features (USF)#
SF \(\psi^{\pi}\)는 정책에 의존적이다. 즉, 정책이 달라지면 \(\psi^{\pi}\) 함수를 다시 찾아야 한다. 정책 업데이트가 계속 발생하는 상황에서 SF를 사용하긴 어려울 것이다. 따라서 우리는 하나의 함수가 상태와 행동 뿐만 아니라 정책에 대한 정보를 입력 받아서 해당 정책의 SF를 출력해주기를 바랄 것이다. 각 정책을 \(k\)-차원 벡터로 인코딩해주는 함수 \(e: \Pi \rightarrow \mathbb{R}^k\)가 존재한다고 하자. 여기서 \(\Pi\)는 모든 정책의 집합이다. Universal Successor Feature (USF) \(\psi: \mathcal{S} \times \mathcal{A} \times \mathbb{R}^k \rightarrow \mathbb{R}^d\)은 다음과 같이 정의된다.
\(\psi(s, a, e(\pi))\)을 찾아낼 수만 있다면, 어떤 정책 \(\pi\)와 어떤 보상함수 \(r_{\mathbf{w}}\)가 주어지더라도, 그것의 행동가치함수를 빠르게 찾아낼 수 있다.
구현 관점에서 봤을 때, 우리는 하나의 정책이 task vector \(\mathbf{w}\)가 주어질 때마다 보상함수 \(r_{\mathbf{w}}\)에 알맞게 행동하는 것을 원할 것이다. 이를 위하여 우리는 task vector \(\mathbf{w}\)에 conditioned된 정책 \(\pi(a|s,\mathbf{w})\)를 사용한다. 그리고 정책 인코딩 함수를 \(e:\pi(a|s, \mathbf{w}) \mapsto \mathbf{w}\)으로 정의한다. 정책 \(\pi(a|s, \mathbf{w})\)의 보상함수 \(r_\mathbf{w}\)에 대한 행동가치함수 \(\psi(s, a, \mathbf{w})^\top \mathbf{w}\)를 계산하여 Q-learning 또는 actor-critic 방식으로 정책을 찾아가게 된다.
Unsupervised Skill Discovery (USD)#
Unsupervised skill discovery (USD) 분야는 보상함수가 주어지지 않은 환경에서 latent vector \(z \in \mathbb{R}^d\)를 입력 받는 정책 \(\pi(a \mid s, z)\)을 찾는다. 이때, 선택된 latent vector \(z\)에 따라서 \(\pi(a \mid s, z)\)의 behavior가 달라지는 것을 원한다. 즉, 서로 다른 두 latent vector \(z_1\)과 \(z_2\)에 대하여 두 정책 \(\pi(a \mid s, z_1)\)과 \(\pi(a \mid s, z_2)\)이 만들어내는 trajectories를 각각 \(\tau_{\pi_{z_1}}\), \(\tau_{\pi_{z_2}}\)라고 할 때, \(\tau_{\pi_{z_1}}\)와 \(\tau_{\pi_{z_2}}\)의 어떤 statistics (function of random variables) \(f(\tau_{\pi_{z_1}})\)과 \(f(\tau_{\pi_{z_2}})\)가 서로 다른 분포를 갖기를 바란다. 이를 반대로 말하자면, 만약 한 latent vector \(z\)가 주어지면, \(f(\tau_{\pi_{z}})\)의 분포를 예측할 수 있게 된다는 것이다. 이 목표는 latent variable \(z\)와 \(f(\tau_{\pi_{z}})\) 사이의 mutual information을 최대화하여 달성할 수 있다.
여기서 \(f\)는 주로 trajectory에서 샘플링한 상태 \(s\)를 사용한다. DIAYN 이전 연구에서는 \(f\)로 초기 상태와 마지막 상태의 순서쌍 \((s_0, s_T)\)을 사용하기도 한다. VISR에서는 전자를 사용한다. 그러면 사실상 위 mutual information은 다음과 같다.
위를 최대화하는 매개변수화된 정책 \(\pi_\theta(a \mid s, z)\)를 찾으면 된다. 보통 \(z\)는 균등분포 또는 정규분포 등의 고정된 사전분포 \(p(z)\)에서 샘플링된다고 가정하기 때문에 첫 번째 텀은 \(\theta\)에 무관하다. 따라서 위 mutual information을 최대화하기 위해서는 \(-\mathcal{H}(z | s)\)을 최대화하면 된다.
기댓값을 표본평균으로 근사하기 위해서는 확률변수 \(s\)와 \(z\)의 joint distribution에서 샘플링을 해야 한다. 이는 사전분포 \(p(z)\)에서 \(z\)를 샘플링한 후 해당 \(z\)를 입력 받는 정책 \(\pi_\theta(a|s,z)\)으로 rollout하여 상태 \(s\)를 방문하는 것을 통해 구현될 수 있다.
Variational inference를 이용한 목적함수 최대화#
하지만 \((s, z)\)를 샘플링했다고 해서 사후분포 \(p(z|s)\) 값을 계산할 수 있는 것은 아니다. 실제 사후분포 대신 우리가 계산할 수 있는 매개변수화된 분포 \(q_\omega(z | s)\)를 사용하면 다음과 같은 부등식이 성립한다. \(\omega\)는 학습 가능한 파라미터이다.
부등식의 우변을 variational lower bound라고 부르고, 사후분포 대신 사용하는 \(p_\omega(z|s)\)를 variational distribution이라고 부른다.
참고로 \(\omega\)에 대해 우변을 최대화하면 variational distribution \(q_\omega(z|s)\)은 사후분포 \(p(z|s)\)에 KL divergence 관점에서 가까워지고, KL Divergence가 0이 됐을 때 \(\theta\)에 대해 우변을 최대화하면 좌변도 함께 최대화된다. 즉, variational lower bound를 최대화하면 우리의 목적함수도 최대화되는 것이다.
샘플을 통한 variational lower bound 최대화#
우리는 2개의 파라미터 \(\theta\)와 \(\omega\)에 대해 variationa lower bound을 최대화시킬 것이다. 우선 \((s, z)\) 샘플링은 위에서 말했던 것처럼 사전분포 \(p(z)\)에서 \(z\)를 샘플링한 후 해당 \(z\)를 입력 받는 정책 \(\pi_\theta(a|s,z)\)으로 rollout하게 된다. Rollout을 통해 모아 놓은 \( \left\{ (s_i, a_i, z_i) \right\}_{i=1}^{K}\) 샘플들로 다음 표본평균을 계산하게 된다.
위 목적함수를 \(\omega\)에 대해서 최대화하는 것은 쉽다. 구한 표본평균을 \(\omega\)에 대해서 미분하면 되기 때문이다. 하지만, 정책 파라미터 \(\theta\)에 대해서 목적함수를 최대화하는 것은 불가능하다. 위에서 정책 파라미터 \(\theta\)가 관여한 곳은 정책을 rollout하여 방문한 상태 \(s\)들에 녹아있기 때문이다.
한편, model-free 강화학습 알고리즘들은 보상함수를 모르더라도 보상함수의 기댓값을 최대화하는 정책을 찾을 수 있다. 각 transition \((s, a, z)\)에 대하여 보상함수를 다음과 같이 정의해보자.
위 보상함수에 model-free 강화학습 알고리즘을 적용하면, 우리의 목적함수 \(\mathcal{L}(\theta, \omega)\)를 최대화하는 정책을 찾을 수 있게 된다. 환경에서 정의하는 보상함수가 아닌 강화학습 알고리즘 자체가 특정 목적을 달성하기 위해 정의한 위와 같은 보상함수를 intrinsic reward라고 부른다.
여기까지가 USD 분야에서 많이 사용되는 목적함수를 알아보았다. 이 목적함수를 실제로 어떻게 구현하는지에 따라서 다양한 skill discovery 알고리즘이 된다. 다음으로 VISR이 위 목적함수를 어떻게 구현하는지에 대해서 알아보자.
Variational Intrinsic Successor Features (VISR)#
이제 SF를 이용하여 USD를 수행하는 VISR 알고리즘에 대해 알아볼 것이다. 먼저, USD를 설명할 때 관례를 따라 latent vector를 \(z \in \mathbb{R}^d\)라고 표기했는데, 이를 싹 다 \(\mathbf{w}\in\mathbb{R}^d\)로 표기하자. SF에서는 보상함수를 다음과 같이 표현된다고 가정했었다.
조금 더 간결해질 수 있도록 보상함수가 현재 상태에만 의존한다고 가정하자.
한편, USD에서는 intrinsic reward를 다음과 같이 정의했다.
VISR에서는 이 두 개를 같다고 등식을 세우면 다음과 같다.
VISR에서는 위 등식을 만족시키기 위하여 variational distribution \(q_\omega\)을 von Mises–Fisher (VMF) distribution으로 모델링하게 된다.
VMF 분포는 \(d\)-차원 unit random vector \(\mathbf{x}\)에 대해 다음과 같이 정의된다.
이때 \(\kappa \ge 0\)와 \(\mathbf{\mu}\)는 분포의 매개변수이다. \(\mathbf{\mu}\)와 \(\mathbf{x}\)는 norm이 1인 unit vector이어야 한다. \(C_d(\kappa)\)는 확률밀도함수를 적분했을 때 1이 되기 위한 normalization 상수이다. VISR에서는 \(\kappa=1\)인 VMF 분포를 variational distribution \(q_\omega(\mathbf{w} | s)\)로 사용하게 된다. 따라서 variational distribution의 모델 파라미터는 \(\omega=\phi(s)\)가 된다.
한편, SF에서 정책 인코딩을 \(e:\pi(a|s, \mathbf{w}) \mapsto \mathbf{w}\)으로 정의했고, 보상함수 \(r_\mathbf{w}\)에 대한 행동가치함수는 USF를 통해 \(Q^{\pi}_{\mathbf{w}}(s, a) = \psi(s, a, \mathbf{w})^\top \mathbf{w}\)으로 빠르게 계산할 수 있었다. 이는 정책 네트워크가 \(\mathbf{w}\)를 입력받는 것으로 구현된다. 즉, \(\pi_\theta(a \mid s, \mathbf{w})\)을 사용한다. 행동가치함수 \(Q^{\pi}_{\mathbf{w}}(s, a)\)를 critic으로 사용하여 actor-critic 방식으로 정책 네트워크를 업데이트하게 된다.
수도 알고리즘#
VISR 논문에서는 Atari 환경에서 실험을 했기 때문에 Q-learning 기반의 수도 알고리즘이 제공된다. 필자는 continuous action space에 더 관심이 많기 때문에 off-policy actor-critic 기반의 수도 알고리즘을 아래와 같이 적어보았다. 아래 수도 알고리즘을 따랐을 때 학습이 잘 될지는 모르겠다. 흐름만 참고하면 좋을 것 같다.
뉴럴 네트워크 초기화: \(\pi_\theta\), \(\phi\), \(\psi\)
Replay buffer 초기화: \(\mathcal{D} \leftarrow \emptyset\)
매 에피소드마다 다음을 진행
Task vector 샘플링 및 정규화: \(\mathbf{w} \sim \mathcal{N}(\mathbf{0}, I). \mathbf{w} \leftarrow \frac{\mathbf{w}}{\lVert \mathbf{w} \rVert_2}\)
초기 상태 관측: \(s_0 \sim d_0(s)\)
매 timestep마다 다음을 진행
상태 인코딩: \(z_t \leftarrow \phi(s_t)\)
행동 결정: \(a_t \sim \pi_\theta(\cdot| z_t, \mathbf{w})\)
행동 수행 및 다음 상태 관측
Replay buffer에 저장: \(\mathcal{D} \leftarrow \mathcal{D} \cup \left\{ (s_t, a_t, s_{t+1}, \mathbf{w}) \right\}\)
네트워크 업데이트 (hat 심볼은 stop-gradient)
\(\left\{ (s_i, a_i, s_{i+1}, \mathbf{w}_i) \right\}_{i=1}^{N} \sim \mathcal{D}\)
\(z_i \leftarrow \phi(s_i), z_{i+1} \leftarrow \phi(s_{i+1})\)
\(r_i \leftarrow z_i^\top \mathbf{w}_i\)
\(a' \sim \pi_\theta(\cdot|z_{i+1}, \mathbf{w}_i)\)
\(y_i \leftarrow r_i + \gamma \left( \psi(\hat{z}_{i+1}, a', \mathbf{w}_i)^\top \mathbf{w}_i - \alpha \log \pi_\theta(a' |z_{i+1},\mathbf{w}_i)\right)\) \(\#\) Double Q-learning도 가능, SAC의 soft Q-target
\(\text{loss}_\psi = \sum_i \left(\psi(\hat{z}_i, a_i, \mathbf{w}_i)^\top \mathbf{w}_i - \hat{y}_i \right)^2\)
\(\text{loss}_\phi = - z_i^{\top} \mathbf{w}_i\)
\(a \leftarrow \pi_\theta(\cdot|z_{i}, \mathbf{w}_i)\)
\(\text{loss}_\theta = \sum_i \left( \psi(\hat{z}_i, a, \mathbf{w}_i)^\top\mathbf{w}_i - \alpha \log \pi_\theta (a | \hat{z}_i, \mathbf{w}_i )\right)\) \(\#\) SAC의 reparametrization objective
\(\pi_\theta, \psi, \pi\) 네트워크 업데이트