ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [1] pykan 라이브러리 사용법 : Hello KAN!
    PROGRAMMING/python 2024. 10. 13. 15:20

    Kolmogorov-Arnold Network(KAN)을 가장 쉽게 사용할 수 있는 방법! pykan 라이브러리 사용법에 대해 정리해보고자 한다. 

     

    ※ 모든 내용은 https://kindxiaoming.github.io/pykan/intro.html 의 내용을 번역한 것으로 모든 저작권은 해당 페이지에 있습니다. 부정확한 표현이 있을 수 있으니 원문 페이지를 꼭 참고해주세요!

     

    https://kindxiaoming.github.io/pykan/kan.html 

     

    Kolmogorov-Arnold Representation Theorem

    Kolmogorov-Arnold Representation Theorem에 따르면 다변수 연속함수 $f$는 bounded domain에서 단일 변수의 연속함수와 이항 연산인 덧셈의 유한 합성 꼴로 나타낼 수 있다. 즉, 아래와 같은 표현이 가능하다.

    이 정리는 모든 다변수 함수는 덧셈이라고 할 수 있으며, 단일 변수 함수와 덧셈을 통해 표현될 수 있음을 보여준다. 이 정리에서는 2층의 폭이 n인 Kolmogorov-Arnold Representation으로 표현이 한정되어 있어, KAN 원본 논문에서는 이를 일반화하여 확장한다. 

     

    Kolmogorov-Arnold Network(KAN)

    Kolmogorov-Arnold Representation은 아래와 같은 matrix의 형태로 고쳐쓸 수 있다.

    여기서 Kolmogorov-Arnold layer는 다음과 같이 정의 가능하다.

    Kolmogorov-Arnold Network(이하 KAN)은 Kolmogorov-Arnold layer(이하 KA layer)를 쌓아 만들 수 있으며, 우리는 앞으로 $L$개의 레이어에 대해 $l$번째 레이어의 $\boldsymbol{\Phi}_l$가 $(n_{l+1}, n_{l})$의 모양을 가지도록 한다.

    반면에 MLP(Multi-Layer Perceptron)은 아래와 같이 쓸 수 있다.

     

    Get started with KANs

    pip install pykan

    우선 pykan을 설치해보자

    from kan import *
    model = KAN(width = [2, 5, 1], grid = 5, k = 3, seed = 0)
    f = lambda x : torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
    dataset = create_dataset(f, n_var=2)
    dataset['train_input'].shape, dataset['train_label'].shape
    >> 
    (torch.Size([1000, 2]), torch.Size([1000, 1]))

    model을 만들고, 생성한 함수 $f$에 맞는 dataset을 만들자.

    model(dataset['train_input'])
    model.plot(beta = 100)

    # train the model
    model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);

    * model.train은 더 이상 작동하지 않는다. model.fit으로 학습시키자

    fitting 후 다시 model을 plot해보자. 

    model.prune()
    model.plot()

    model을 prunning 한 다음 다시 그려보자(prunning이 제대로 안 된 것 같지만, 이 부분은 다음에 확인해보는걸로,,

    model = model.prune()
    model(dataset['train_input'])
    model.plot()

    model.fit(dataset, opt="LBFGS", steps=100);
    model.prune()
    model.plot()

    mode = "auto"
    
    if mode == "manual":
      model.fix_symbolic(0, 0, 0, 'sin')
      model.fix_symbolic(0, 1, 0, 'x^2')
      model.fix_sumbolic(1, 0, 0, 'exp')
    elif mode == "auto":
      lib = [ 'x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
      model.auto_symbolic(lib = lib)
      >>
      fixing (0,0,0) with 0
    fixing (0,0,1) with sin, r2=0.9999524354934692, c=2
    fixing (0,1,0) with 0
    fixing (0,1,1) with x^2, r2=0.9999854564666748, c=2
    fixing (1,0,0) with 0
    fixing (1,1,0) with exp, r2=0.999994695186615, c=2
    model.symbolic_formula()[0][0]

    댓글

Designed by Tistory.