Skip to content

Commit

Permalink
Merge pull request #98 from AlessandroFlati/develop
Browse files Browse the repository at this point in the history
This will solve CPU-only, CUDA-only and any mix of them.
  • Loading branch information
KindXiaoming authored May 7, 2024
2 parents 116f399 + c857dd6 commit 70b7b8d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion kan/KAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def initialize_from_another_model(self, another_model, x):
# spb = spb_parent
preacts = another_model.spline_preacts[l]
postsplines = another_model.spline_postsplines[l]
self.act_fun[l].coef.data = curve2coef(preacts.reshape(batch, spb.size).permute(1, 0), postsplines.reshape(batch, spb.size).permute(1, 0), spb.grid, k=spb.k)
self.act_fun[l].coef.data = curve2coef(preacts.reshape(batch, spb.size).permute(1, 0), postsplines.reshape(batch, spb.size).permute(1, 0), spb.grid, k=spb.k, device=self.device)
spb.scale_base.data = spb_parent.scale_base.data
spb.scale_sp.data = spb_parent.scale_sp.data
spb.mask.data = spb_parent.mask.data
Expand Down
6 changes: 3 additions & 3 deletions kan/KANLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=
if isinstance(scale_base, float):
self.scale_base = torch.nn.Parameter(torch.ones(size, device=device) * scale_base).requires_grad_(sb_trainable) # make scale trainable
else:
self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base).cuda()).requires_grad_(sb_trainable)
self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base).to(device)).requires_grad_(sb_trainable)
self.scale_sp = torch.nn.Parameter(torch.ones(size, device=device) * scale_sp).requires_grad_(sp_trainable) # make scale trainable
self.base_fun = base_fun

Expand Down Expand Up @@ -249,8 +249,8 @@ def initialize_grid_from_parent(self, parent, x):
# preacts: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
x_eval = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
x_pos = parent.grid
sp2 = KANLayer(in_dim=1, out_dim=self.size, k=1, num=x_pos.shape[1] - 1, scale_base=0.).to(self.device)
sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1)
sp2 = KANLayer(in_dim=1, out_dim=self.size, k=1, num=x_pos.shape[1] - 1, scale_base=0., device=self.device)
sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1, device=self.device)
y_eval = coef2curve(x_eval, parent.grid, parent.coef, parent.k, device=self.device)
percentile = torch.linspace(-1, 1, self.num + 1).to(self.device)
self.grid.data = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)
Expand Down

0 comments on commit 70b7b8d

Please sign in to comment.