SGD NaN Fix #43

Open · wants to merge 1 commit into base: master
13 changes: 13 additions & 0 deletions neighbors/_fit.py
@@ -32,6 +32,8 @@ def sgd(
     error_history = np.zeros((n_iterations))
     converged = False
     last_e = 0
+    e = 0
+    error_is_nan = False
     norm_rmse = np.inf
     delta = np.inf
     np.random.seed(seed)
@@ -69,6 +71,11 @@ def sgd(
 
             # Use changes in e to determine tolerance
             e = data[u, i] - prediction  # error
+            # Check if predictions have exploded resulting in NaN errors
+            # and prevent propagation by breaking early
+            if np.isnan(e):
+                error_is_nan = True
+                break
 
             # Update biases
             user_bias[u] += learning_rate * (e - user_bias_reg * user_bias[u])
@@ -85,6 +92,11 @@ def sgd(
             # Keep track of total squared error
             total_error += np.power(e, 2)
 
+        # Check if error was NaN
+        if error_is_nan:
+            converged = False
+            break
+
         # Force non-negativity. Surprise does this per-epoch via re-initialization. We do this per sweep over all training data, e.g. see: https://github.com/NicolasHug/Surprise/blob/master/surprise/prediction_algorithms/matrix_factorization.pyx#L671
         user_vecs = np.maximum(user_vecs, 0)
         item_vecs = np.maximum(item_vecs, 0)
@@ -107,6 +119,7 @@ def sgd(
     return (
         error_history,
         converged,
+        error_is_nan,
         this_iter,
         delta,
         norm_rmse,
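
For reference, here is the guard pattern in isolation: a minimal, self-contained sketch of an SGD factorization loop with the same early-exit check. Everything in it (the sgd_demo name, the toy update rule, the convergence criterion) is illustrative and simplified; only the NaN check itself mirrors the diff above.

import numpy as np

def sgd_demo(data, n_factors=5, n_iterations=100, learning_rate=0.01, seed=0):
    """Toy SGD matrix factorization with the same NaN early-exit guard."""
    rng = np.random.RandomState(seed)
    n_users, n_items = data.shape
    user_vecs = rng.normal(scale=0.1, size=(n_users, n_factors))
    item_vecs = rng.normal(scale=0.1, size=(n_items, n_factors))
    rows, cols = np.nonzero(~np.isnan(data))  # observed entries
    converged = False
    error_is_nan = False

    for this_iter in range(n_iterations):
        total_error = 0.0
        for u, i in zip(rows, cols):
            e = data[u, i] - user_vecs[u] @ item_vecs[i]  # prediction error
            # A too-large learning rate can make the factors explode to inf,
            # turning e into NaN; break before NaN poisons every parameter.
            if np.isnan(e):
                error_is_nan = True
                break
            u_old = user_vecs[u].copy()
            user_vecs[u] += learning_rate * e * item_vecs[i]
            item_vecs[i] += learning_rate * e * u_old
            total_error += e ** 2
        if error_is_nan:
            converged = False
            break
        if total_error < 1e-8:  # toy convergence criterion
            converged = True
            break
    return user_vecs, item_vecs, converged, error_is_nan

Note that, as in the diff, the exit happens in two stages: the inner loop breaks on the first NaN, and the flag then breaks the outer iteration loop before any per-sweep post-processing runs on the now-invalid parameters.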
5 changes: 5 additions & 0 deletions neighbors/models.py
@@ -523,6 +523,7 @@ def fit(
         (
             error_history,
             converged,
+            error_is_nan,
             n_iter,
             delta,
             norm_rmse,
@@ -569,11 +570,15 @@ def fit(
         self._delta = delta
         self._norm_rmse = norm_rmse
         self.converged = converged
+        self.error_is_nan = error_is_nan
         if verbose:
             if self.converged:
                 print("\n\tCONVERGED!")
                 print(f"\n\tFinal Iteration: {self._n_iter}")
                 print(f"\tFinal Delta: {np.round(self._delta)}")
+            elif self.error_is_nan:
+                print("\tFAILED TO CONVERGE (predictions are NaN)")
+                print(f"\n\tFinal Iteration: {self._n_iter}")
             else:
                 print("\tFAILED TO CONVERGE (n_iter reached)")
                 print(f"\n\tFinal Iteration: {self._n_iter}")
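
With the flag surfaced on the fitted model, calling code can distinguish a diverged run from ordinary non-convergence and, for example, retry with a smaller step size. The following is a hypothetical usage sketch: the NNMF_sgd class name and the fit() keyword arguments are assumptions for illustration, and only the error_is_nan attribute comes from this PR.

# Hypothetical caller: retry with smaller step sizes when SGD diverges.
# NNMF_sgd and the fit() keywords are assumed names, not taken from the diff.
model = NNMF_sgd(data=ratings)
for lr in (1e-1, 1e-2, 1e-3, 1e-4):
    model.fit(learning_rate=lr, verbose=False)
    if not model.error_is_nan:
        break  # errors stayed finite at this step size
else:
    raise RuntimeError("SGD produced NaN errors at every learning rate tried")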