ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Neural ODE 논문 리뷰 [18' NeurIPS] Neural Ordinary Differential Equation(Chen et al.)
    논문 리뷰/Generative Model 2024. 7. 15. 20:39

    몇 번을 봐도 이해가 어렵던 Neural ODE를 드디어 정리해보겠다.

    증명이 이해가 안되면 다음 줄로 못 넘어가는 성격인지라 논문의 내용 + Appendix에 초점을 맞춰 정리해보고자 한다.

     

    Neural ODE는 RNN과 Normalizing flow의 식이 일종의 ODE처럼 생겼다는 점에서 착안하여, neural net을 원하는 함수의 미분값으로 보고 학습한다. 우리에게 잘 알려진 Euler method나 Runge Kutta method를 통해 ODE를 풀고(이게 논문에서 말하는 ODEsolver이다.) Backpropagation을 위해 adjoint sensitivity method를 가져왔다. 

     

    우선 Appendix를 잘 정리해두면 이해가 쉬워서 Appendix를 먼저 살펴보자.

     

    우선 instantaneous change of variable theorem이 나온다. 증명이 논문에 자세히 나와있어 정리는 생략한다.

     

    그리고 아래 adjoint state에 대한 미분식 증명도 나온다. 

    마지막으로 아래 식까지 가지고 논문을 이해하러 가보자.

     

    Introduction

    Figure 1은 RNN과 NODE를 비교한 그림으로 step size가 결정된 RNN과 달리 NODE는 evaluation point가 flexible하다.

    RNN이나 Normalizing flow에서는 여러 층의 hidden state을 거쳐 우리가 원하는 결과가 나오게끔 학습시켰다. 이 모양은 Euler discretization의 관점에서 이해해볼 수 있는데, 만약 우리가 은닉 층을 더 추가하고 step의 크기를 줄인다면 (2)번 식과 같은 ODE(미분방정식)를 얻을 수 있다.

     

    여기서 $f$는 the derivative of the hidden state으로 neural network로 표현된다.

     

    Reverse-mode automatic differentiation of ODE solutions

    이 논문에서는 forward 방향은 기존의 ODE Solver를 이용한다. 여기서 주목할 점은 사용한 ODE Solver는 일종의 Black box로 취급하여 그 결과값만을 고려한다는 점이다. backward 방향은 adjoint sensitivity method를 사용하는데, 이 방법을 사용하면 memory cost가 낮다는 장점이 있다.

    우선 $[t_0, t_1]$ 구간을 생각해보자. 우리의 loss function은 다음과 같이 주어진다.

    loss function을 optimize하기 위해서는 $ \frac{dL}{d\theta}$를 알아야 한다. 

     

    우선 adjoint를 정의하자.

    Appendix에서 이미 봤던대로 아래와 같은 식 (4)를 얻을 수 있다.

    $a(t_0)$는 ODE Solver를 통해 backward로 ($t_1$ 에서 $t_0$로) 구할 수 있다는 점을 기억하자.

    $\frac{dL}{d\theta}$는 Appendix에서 봤듯이 아래와 같이 표현 가능하다.

    여기서 vector-Jacobian product인 $a(t)^T \pd{f}{z}$, $a(t)^T \pd{f}{\theta}$은 automatic differentiation을 통해서 쉽게 계산할 수 있다. 또힌 $z$, $a$, $\pd{L}{\theta}$를 구하기 위해서는 이 세 개의 벡터를 하나의 벡터로 합해서 단일 ODE Solver를 통해 계산하면 된다.

     

     

    Continuous Normalizing Flows

    Normalizing flow에서는 chage of varaible theorem을 이용해 probability를 계산할 수 있다는 장점을 가지고 있으나, 모델에 따라 복잡한 Jacobian determinant 계산을 해야 한다는 단점이 있다.(이느 Normalizing Flow에서 일종의 bottleneck으로 꼽힌다.)

     

    이러한 단점은 discrete한 layer가 아닌 continuous transformation에서 단순화 될 수 있다.(아래 식 또한 Appendix에서 이미 본 식!)

    원래 Jacobian determinant였던 term이 trace로 바뀌면서 계산이 간단해졌다는 점을 눈여겨 보아야 한다.

     

    또한 $f(z(t), t)$이 $t$에 따라 변하도록 하고 gating mechanism을 도입하므로써 시간에 따른 연속적인 변환 CNF를 고안하였다.

     

    planar flow에서 NF와 CNF를 비교한 실험은 아래와 같다.

    CNF의 결과가 좀 더 target data에 가까우며 loss 값 또한 더 작다.

     

    A generative latent function time-series model

    기본 세팅

    generative model로써의 Neural ODE를 보여주기 위해 이 논문에서는 time-series를 모델링해보았다. Neural ODE는 latent trajectory를 나타낼 수 있으며, 각각의 trajectory는 초기값 $z_{t_0}$와 $f$(=global latent dynamics)에 의해 결정된다.

    여기서 함수 $f$는 time-invariant하기 때문에 $t$가 아닌 $z(t)$에만 의존하며, 이로 인해 모든 $t$에 대해 동일한 $f$, 즉 동일한 global latent dynamic이 작용한다고 볼 수 있다.

    이 논문에서는 latent-variable model을 VAE로 상정하고 encoder는 RNN, decoder는 Neural ODE로 구성하였다. ODE를 generative model로 prediction을 진행하였으며, extrapolation도 실험해보았다. 

    RNN에 비해 Neural ODE가 훨씬 더 잘 예측하는 것을 알 수 있다.

    Neural ODE의 장점

    1. Memory efficiency

    Neural ODE는 adjoint sensitivity method를 사용하기 때문에 중간 단계를 저장하지 않고도 그래디언트를 계산할 수 있다.

    →train models with constant memory cost

     

    2. Adaptive computation

    ODE Solver는 지속적으로 발전하고 있으며, 정확한 계산을 가능하게 하는 방법론이 많이 제시되었다.

     

    3. Scalable and invertible normalizing flow

    instantaneous change of variables를 통해 chage of variable 식이 더욱 간단해졌다.

     

    4. Continuous time-series models

    동일한 시간 간격이 아닌 arbitrary times에 대한 data도 다룰 수 있다.

     

    Code

    https://github.com/rtqichen/torchdiffeq/blob/master/examples/ode_demo.py

     

    torchdiffeq/examples/ode_demo.py at master · rtqichen/torchdiffeq

    Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation. - rtqichen/torchdiffeq

    github.com

    Demo file은 spiral ODE에 대한 구현이 담겨있다. 이 파일을 간단하게 정리해보았다.

    import os
    import time
    import numpy as np
    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import matplotlib.pyplot as plt

     

    우선 필요한 라이브러리를 먼저 불러온다.

     

    method = 'dopri5' # ODE솔버로 Runge-Kutta 방법의 일종
    data_size = 1000  # data point의 수
    batch_time = 10   # 각 배치의 시간 길이 = 10
    batch_size = 20   # 각 배치의 크기 = 20
    niters = 2000     # 학습반복 횟수 = 2000
    test_freq = 20    # 테스트 주기 = 20
    adjoint = False   # Adjoint Sensitivity Method 사용 여부
    gpu = 0           # 사용할 gpu 장치 번호
    if adjoint : 
        from torchdiffeq import odeint_adjoint as odeint
    else:
        from torchdiffeq import odeint
        
    device = torch.device('cuda:'+str(gpu) if torch.cuda.is_available() else 'cpu')

    adjoint가 True이면 backpropagation을 adjoint sensitivity method 실행한다.

    true_y0 = torch.tensor([[2., 0.]]).to(device)
    t = torch.linspace(0., 25., data_size).to(device)
    true_A = torch.tensor([[-0.1, 2.0],[-2.0, -0.1]]).to(device)
    
    class Lambda(nn.Module):
        def forward(self, t, y):
            return torch.mm(y**3, true_A)
            
    with torch.no_grad():
    	true_y = odeint(Lambda(), true_y0, t, method = 'dopri5')

    $y_0 = (2, 0)$을 초기 점으로 하며 $\frac{dy(t)}{dt} = y^{3} \dot A$로 생성된 데이터를 생각해보자.

     

    odeint(func, y_0, t)

    - func : input이 (t, y)이고 output이 dy/dt인 미분방정식을 정의하는 함수

    - y_0 : 초기 상태

    - t : 시간 지점들의 배열

    def get_batch():
        s = torch.from_numpy(np.random.choice(np.arange(data_size - batch_time, dtype= np.int64), batch_size, replace = False))
        # np.arange(~) : 0부터 data_size - batch_time까지의 정수 배열을 생성 : 가능한 시작 인덱스를 나타냄
        # np.random.choice : 가능한 시작 인덱스 중에서 batch_size개를 랜덤하게 선택
        # replace_False : 중복을 허용하지 않음
        batch_y0 = true_y[s]           # (batch_size, 2)
        batch_t = t[:batch_time]       # (batch_time, )
                                       # (batch_time, batch_size, 2) 
        batch_y = torch.stack([true_y[s + i] for i in range(batch_time)], dim = 0)
        return batch_y0.to(device), batch_t.to(device), batch_y.to(device)

    get_batch 함수의 역할
    1. 배치 데이터의 생성 : 전체 데이터 중에서 랜덤하게 일부 데이터를 선택하여 배치를 만든다
    2. 배치 초기 상태 및 목표값 반환 : 선택된 배치의 초기 상태와 목표 값을 반환한다. 이를 통해 모델이 학습할 수 있다.

    def visualize(true_y, pred_y, func):
        plt.figure(figsize = (12, 4))
        
        # 131 
        # 첫 번째 자리 : 열의 수
        # 두 번째 자리 : 행의 수
        # 세 번째 자리 : 현재 플롯의 위치
        plt.subplot(131)
        plt.title('Trajectories')
        plt.xlabel('t')
        plt.ylabel('x, y')
        plt.plot(t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 0], 'g-', label = 'True y0')
        plt.plot(t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 1], 'b-', label = 'True y1')    
        plt.plot(t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 0], 'r--', label = 'Pred y0')    
        plt.plot(t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 1], 'y--', label = 'Pred y1')
        plt.legend()
        
        plt.subplot(132)
        plt.title('Phrase Portrait')
        plt.xlabel('x')
        plt.ylabel('y')
        plt.plot(true_y.cpu().numpy()[:, 0, 0], true_y.cpu().numpy()[:, 0, 1], 'g-')
        plt.plot(pred_y.cpu().numpy()[:, 0, 0], pred_y.cpu().numpy()[:, 0, 1], 'b-')
        
        plt.subplot(133)
        plt.title('Learned Vector Field')
        plt.xlabel('x')
        plt.ylabel('y')
        
        #mgrid : 다차원 격자 점을 생성하는 도구
        # 등간격으로 분포된 값을 나타내는데 유용
        # 21j는 주어진 범위 내에서 21개의 등간격 포인트를 생성
        # x, y는 각각 21 * 21 크기의 2차원 배열
        # x : 각 행이 동일한 값 / y : 각 열이 동일한 값
        y, x = np.mgrid[-2:2:21j, -2:2:21j]
        
        #[x, y] : (2, 21, 21)
        # np.stack함수는 주어진 배열들을 새로운 축(여기서는 마지막 축)을 따라 쌓는다
        # np.stack([x, y], -1) : (21, 21, 2)
        # reshape 해서 (21*21,2) -> cpu 이동/ 그래프에서 분리 / numpy배열로 변환
        dydt = func(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2)).to(device)).cpu().detach().numpy()
        
        #magnitude 
        #dydt[:, 0] ** 2 : 각 벡터의 x성분의 제곱
        #dydt[:, 1] ** 2 : 각 벡터의 y성분의 제곱
        #reshape(-1, 1) : 결과 배열을 (21*21, 1)크기로 변환
        mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
        # normalzie -> 모든 벡터 크기 1
        dydt = (dydt / mag)
        dydt = dydt.reshape(21, 21, 2)
        
        #streamplot : 벡터 필드를 시각화
        plt.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color = "black")
        plt.xlim(-2, 2)
        plt.ylim(-2, 2)
        
        plt.tight_layout()
        plt.show()

    그림을 그려주는 visualize도 정의해주자.

    class ODEFunc(nn.Module):
        def __init__(self):
            super(ODEFunc, self).__init__()
            
            self.net = nn.Sequential(
                nn.Linear(2, 50), 
                nn.Tanh(), 
                nn.Linear(50, 2),
            )
            
            for m in self.net.modules():
                if isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, mean =0, std = 0.1)
                    nn.init.constant_(m.bias, val = 0)
        def forward(self, t, y):
            return self.net(y**3)

    ODEFunc 클래스를 학습하여 신경망이 주어진 데이터에 맞춰 상태 변화를 예측하도록 한다.

     

    이때 y**3을 해주는 이유는 입력 데이터의 변형을 통해 신경망이 더 복잡한 비선형 관계를 학습할 수 있도록 도와주기 때문이다. 원래 데이터의 비선형성을 증가시켜 신경망이 더 복잡한 패턴을 학습할 수 있도록 도우며, 이는 모델의 성능을 향상시킬 수 있지만, 특정한 이유 없이 사용되기도 한다.

    class RunningAverageMeter(object):
        def __init__(self, momentum = 0.99):
            self.momentum = momentum
            self.reset()
        def reset(self):
            self.val = None
            self.avg = 0
        def update(self, val):
            if self.val is None:
                self.avg = val
            else:
                self.avg = self.avg * self.momentum + val * (1 - self.momentum)
            self.val = val
    if __name__ == '__main__':
        ii = 0
        
        func = ODEFunc().to(device)
        optimizer = optim.RMSprop(func.parameters(), lr = 1e-3)
        end = time.time()
        
        time_meter = RunningAverageMeter(0.97)
        loss_meter = RunningAverageMeter(0.97)
        
        for itr in range(1, niters + 1):
            optimizer.zero_grad()
            batch_y0, batch_t, batch_y = get_batch()
            pred_y = odeint(func, batch_y0, batch_t).to(device)
            loss = torch.mean(torch.abs(pred_y - batch_y))
            loss.backward()
            optimizer.step()
            
            time_meter.update(time.time()-end)
            loss_meter.update(loss.item())
            
            if itr % test_freq == 0:
                with torch.no_grad():
                    pred_y = odeint(func, true_y0, t)
                    loss = torch.mean(torch.abs(pred_y - true_y))
                    print(f'Iter {itr:04d} | Total loss {loss.item():.6f}')
                    visualize(true_y, pred_y, func)
            end = time.time()

     

    iter = 20, loss = 0.66
    iter = 70, loss = 0.16

    댓글

Designed by Tistory.