HIerarchical Reinforcement learning with Off-policy correction (HIRO)
Key idea: FuN + TD3 + off-policy correction
•
FuN의 한계점
1.
FuN은 worker와 manager가 모두 현재 policy로만 데이터를 수집하는 방식인 on-policy 이다. 따라서 수집 효율이 낮고, sample efficiency가 매우 떨어진다.
2.
manager가 설정하는 subgoal 은 latent space의 vector로 이 subgoal 이 실제 action과 얼마나 의미 있게 연결되는지 해석이 어렵다. reward도 subgoal 의 cosine similarity 기반으로 추정됨 → 불안정할 가능성이 있다.
3.
manager와 worker가 서로 다른 time step 단위로 update되면서 non-stationarity 발생한다. reward assignment가 어렵고, worker가 subgoal 을 잘 따라가지 못하면 manager의 학습도 실패할 수 있다.
Off-policy correction
Off-policy learning : behavior policy로 수집된 data를 사용해서 target policy를 학습하는 방법.
off-policy learning의 문제점 : behavior policy로부터 수집된 데이터가 target policy에서 나올 만한 data가 아닐 수도 있다 → gradient가 잘못된 방향으로 흘러서 학습이 불안정해진다. 이 둘이 다르기 때문에 기댓값 계산이 틀어진다.
그러므로 correction을 해줘야 올바른 방향으로 policy를 update할 수 있다.
•
어떻게 correction을 하는가?
1.
Importance Sampling (IS)
어떤 분포에서 기대값을 구하고 싶지만, 그 분포에서 직접 sampling 할 수 없을 때, 다른 분포로부터 sampling한 데이터를 이용해 원하는 기대값을 보정해서 계산한다.
2.
Clipping or truncation
importance weight를 clipping하거나 truncated Importance sampling을 사용한다.
예: V-trace, Q(λ), GAE 등 → 안정성을 높이며, bias/variance trade-off 를 조정한다.
•
HER에서의 off-policy correction
HER도 off-policy 방식이기 때문에, hindsight goal이 target policy과는 다를 수 있다. 이때에는 correction 기법(예: importance sampling)은 안 쓰는 대신, goal relabeling을 통해 적절한 보상을 주는 것으로 암묵적으로 correction을 수행한다.
HIRO는 FuN처럼 2개의 hierarchy를 사용하는 hierarchical policy architecture 이지만, Off-policy Reinforcement Learning 알고리즘인 TD3 의 매커니즘을 사용하여 sample efficiency와 학습 안정성을 개선하였다.
가 아니라 로 작성한 이유
TD3는 deterministic policy 기반이기 때문. 는 stochastic policy일 때 쓴다.
•
Low-level policy () - worker
Input:
Output:
low-level policy는 현재 상태 와 high-level policy가 제시한 goal 를 입력으로 받아 action 를 선택하고, 이 action은 환경과 직접 상호작용한다. reward 는 이 행동의 결과로서 goal 를 얼마나 잘 따라갔는지를 측정하며, 그 기반으로 계산된다.
Actor-Critic 구조를 사용하며, 이때 critic은 , actor는 로 TD3 기반 학습을 수행한다.
•
High-level policy () - manager
일정 주기마다 (예: steps마다) goal 을 생성한다. 여기서 goal 은 state space 안의 vector 혹은 target state를 의미한다.
reward는 환경에서 받는 실제 reward의 누적값을 의미한다. 단, worker 가 goal 가 아닌 이전 goal 기준으로 행동했을 수 있다. 그래서 goal mismatch로 인한 off-policy 문제 발생한다.
따라서 goal relabeling을 통해 high-level policy도 off-policy로 학습 가능하게 한다.
•
subgoal transition function
Low-level policy 를 위한 intrinsic reward 는 goal state 와의 distance 로 계산된다. goal state 를 현재 state 에서의 relative position 으로 다루기 때문에 state 가 바뀌면 goal state 도 따라 바뀌게 되고, 이를 위해 goal transition function 를 사용한다:
직관적으로 보면 "goal 대로 갔으면 이 와 같았을 것" 과 같다.
Off-policy correction in HIRO
•
(A) Low-level policy correction
low-level policy goal 에 조건화된 policy이다. 따라서 기존 replay buffer에 있는 transition들을 학습할 때, 현재 정책과 mismatch되는 문제가 발생한다. 이를 보완하기 위해 “goal-conditioned reward” 를 사용한다. goal 방향과 실제 이동이 유사할수록 값이 커진다 (작은 오차 → 큰 reward)
goal-conditioned reward
Worker는 Manager가 지정한 goal 를 달성하려는 학습을 하는데, reward가 sparse한 환경에서는 Worker는 reward를 전혀 받을 수 없는 문제가 발생한다. 이런 문제를 해결하기 위해 Worker가 현재 state 와 Manager의 goal 사이의 거리를 줄일 수록 비례해서 intrinsic reward를 준다.
•
(B) High-level policy correction: goal relabeling
high-level policy는 과거 low-level action들의 결과를 보고, 어떤 goal을 주면 그런 action들이 나왔을지를 추정해야 한다. low-level policy는 매 step마다 update하는데, high-level policy은 step마다 update하기 때문이다. replay buffer의 transition 에 대해서 현재의 low-level policy가 실제로 이 행동을 생성했을 것처럼 보이도록 goal을 relabel 한다.
그래서 past low-level actions을 최대화할 수 있도록 실제 goal 대신 relabel된 goal 를 사용해서 학습한다. 이것을 goal relabeling이라고 부른다.
goal을 주었을 때 현재 low-level policy가 replay buffer에 저장되어 있던 것과 같은 행동을 하도록 goal 를 relabeling하여 찾아준다.
•
Step 1 : Collect experience
환경에서 상호작용하며 다음 데이터를 수집한다. (현재 state , high-level policy가 생성한 subgoal , low-level policy 가 수행한 action , 환경에서 받은 reward ) 이들은 replay buffer에 저장된다.
•
Step 2: Low-level policy 학습
low-level policy는 subgoal을 목표로 action을 수행하도록 학습된다. 여기서 는 state 와 함께 입력되고, reward는 subgoal과 실제 이동 간의 차이로 정의된다. (L2 distance) off-policy actor-critic (TD3) 구조로 학습한다.
•
Step 3: High-level policy 학습
high-level policy는 -step마다 한 번 학습되고, 전체 구간 부터 까지 누적 reward를 기반으로 학습한다. 여기서 사용되는 goal 는 goal relabeling 을 통해 얻은 것으로 “이 행동들이 현재 low-level policy 기준으로 나오도록 만들 수 있는 가장 적절한 subgoal은 무엇인가?”를 역으로 추정하는 과정이다.
•
Step 4: Repeat
위의 과정을 반복한다.