Skip to content

Commit

Permalink
further changes to comply with diffxpy api
Browse files Browse the repository at this point in the history
  • Loading branch information
davidsebfischer committed Aug 21, 2019
1 parent e746783 commit 97c4407
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 7 deletions.
8 changes: 5 additions & 3 deletions batchglm/models/base/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,14 @@ def feature_isallzero(self):
return self._feature_allzero

def fetch_x_dense(self, idx):
return self.x[idx]
assert isinstance(self.x, np.ndarray), "tried to fetch dense from non ndarray"

return self.x[idx, :]

def fetch_x_sparse(self, idx):
assert isinstance(self.x, scipy.sparse.csr_matrix), "tried to fetch sparse from non csr matrix"
assert isinstance(self.x, scipy.sparse.csr_matrix), "tried to fetch sparse from non csr_matrix"

data = self.x[idx]
data = self.x[idx, :]

data_idx = np.asarray(np.vstack(data.nonzero()).T, np.int64)
data_val = np.asarray(data.data, np.float64)
Expand Down
4 changes: 2 additions & 2 deletions batchglm/models/base_glm/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ def num_scale_params(self):
return self.constraints_scale.shape[1]

def fetch_design_loc(self, idx):
return self.design_loc[idx]
return self.design_loc[idx, :]

def fetch_design_scale(self, idx):
return self.design_scale[idx]
return self.design_scale[idx, :]

def fetch_size_factors(self, idx):
return self.size_factors[idx]
2 changes: 1 addition & 1 deletion batchglm/train/tf/base_glm_all/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def fetch_fn(idx):
num_design_scale_params=input_data.num_design_scale_params,
num_loc_params=input_data.num_loc_params,
num_scale_params=input_data.num_scale_params,
batch_size=batch_size,
batch_size=np.min([batch_size, input_data.x.shape[0]]),
graph=graph,
init_a=init_a,
init_b=init_b,
Expand Down
2 changes: 1 addition & 1 deletion batchglm/unit_test/test_graph_glm_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
else:
raise ValueError("noise_model not recognized")

batch_size = 100
batch_size = 200
provide_optimizers = {
"gd": False, "adam": False, "adagrad": False, "rmsprop": False,
"nr": False, "nr_tr": False,
Expand Down

0 comments on commit 97c4407

Please sign in to comment.