-
[1] pykan 라이브러리 사용법 : Hello KAN!PROGRAMMING/python 2024. 10. 13. 15:20
Kolmogorov-Arnold Network(KAN)을 가장 쉽게 사용할 수 있는 방법! pykan 라이브러리 사용법에 대해 정리해보고자 한다.
※ 모든 내용은 https://kindxiaoming.github.io/pykan/intro.html 의 내용을 번역한 것으로 모든 저작권은 해당 페이지에 있습니다. 부정확한 표현이 있을 수 있으니 원문 페이지를 꼭 참고해주세요!
https://kindxiaoming.github.io/pykan/kan.html
Kolmogorov-Arnold Representation Theorem
Kolmogorov-Arnold Representation Theorem에 따르면 다변수 연속함수 $f$는 bounded domain에서 단일 변수의 연속함수와 이항 연산인 덧셈의 유한 합성 꼴로 나타낼 수 있다. 즉, 아래와 같은 표현이 가능하다.
이 정리는 모든 다변수 함수는 덧셈이라고 할 수 있으며, 단일 변수 함수와 덧셈을 통해 표현될 수 있음을 보여준다. 이 정리에서는 2층의 폭이 n인 Kolmogorov-Arnold Representation으로 표현이 한정되어 있어, KAN 원본 논문에서는 이를 일반화하여 확장한다.
Kolmogorov-Arnold Network(KAN)
Kolmogorov-Arnold Representation은 아래와 같은 matrix의 형태로 고쳐쓸 수 있다.
여기서 Kolmogorov-Arnold layer는 다음과 같이 정의 가능하다.
Kolmogorov-Arnold Network(이하 KAN)은 Kolmogorov-Arnold layer(이하 KA layer)를 쌓아 만들 수 있으며, 우리는 앞으로 $L$개의 레이어에 대해 $l$번째 레이어의 $\boldsymbol{\Phi}_l$가 $(n_{l+1}, n_{l})$의 모양을 가지도록 한다.
반면에 MLP(Multi-Layer Perceptron)은 아래와 같이 쓸 수 있다.
Get started with KANs
pip install pykan
우선 pykan을 설치해보자
from kan import * model = KAN(width = [2, 5, 1], grid = 5, k = 3, seed = 0) f = lambda x : torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2) dataset = create_dataset(f, n_var=2) dataset['train_input'].shape, dataset['train_label'].shape >> (torch.Size([1000, 2]), torch.Size([1000, 1]))
model을 만들고, 생성한 함수 $f$에 맞는 dataset을 만들자.
model(dataset['train_input']) model.plot(beta = 100)
# train the model model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);
* model.train은 더 이상 작동하지 않는다. model.fit으로 학습시키자
fitting 후 다시 model을 plot해보자.
model.prune() model.plot()
model을 prunning 한 다음 다시 그려보자(prunning이 제대로 안 된 것 같지만, 이 부분은 다음에 확인해보는걸로,,
model = model.prune() model(dataset['train_input']) model.plot()
model.fit(dataset, opt="LBFGS", steps=100); model.prune() model.plot()
mode = "auto" if mode == "manual": model.fix_symbolic(0, 0, 0, 'sin') model.fix_symbolic(0, 1, 0, 'x^2') model.fix_sumbolic(1, 0, 0, 'exp') elif mode == "auto": lib = [ 'x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs'] model.auto_symbolic(lib = lib) >> fixing (0,0,0) with 0 fixing (0,0,1) with sin, r2=0.9999524354934692, c=2 fixing (0,1,0) with 0 fixing (0,1,1) with x^2, r2=0.9999854564666748, c=2 fixing (1,0,0) with 0 fixing (1,1,0) with exp, r2=0.999994695186615, c=2
model.symbolic_formula()[0][0]
'PROGRAMMING > python' 카테고리의 다른 글
[GluonTS 3탄] GluonTS로 시계열 데이터 예측해보기 - 심화편 (0) 2024.09.20 [GluonTS 2탄] GluonTS로 시계열 데이터 예측해보기 (1) 2024.09.15 [GluonTS 1탄] GluonTS란 무엇일까? (0) 2024.09.15 [PyTorch] gather 함수 설명(slicing 없이 특정 인덱스만 추출하기) (0) 2024.09.11