Skip to content

Commit

Permalink
Merge pull request #129 from Jim137/develop
Browse files Browse the repository at this point in the history
Fix dtype error when input dtype is Double
  • Loading branch information
KindXiaoming authored May 8, 2024
2 parents fcefaba + d4305a3 commit a666ec3
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions kan/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def coef2curve(x_eval, grid, coef, k, device="cpu"):
'''
# x_eval: (size, batch), grid: (size, grid), coef: (size, coef)
# coef: (size, coef), B_batch: (size, coef, batch), summer over coef
if coef.dtype != x_eval.dtype:
coef = coef.to(x_eval.dtype)
y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device))
return y_eval

Expand Down

0 comments on commit a666ec3

Please sign in to comment.