diff --git a/kan/spline.py b/kan/spline.py index 1be097f1..6f64d883 100644 --- a/kan/spline.py +++ b/kan/spline.py @@ -97,6 +97,8 @@ def coef2curve(x_eval, grid, coef, k, device="cpu"): ''' # x_eval: (size, batch), grid: (size, grid), coef: (size, coef) # coef: (size, coef), B_batch: (size, coef, batch), summer over coef + if coef.dtype != x_eval.dtype: + coef = coef.to(x_eval.dtype) y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device)) return y_eval