diff --git a/sam/onyx/generate_matrices.py b/sam/onyx/generate_matrices.py index 7d2fb118..956bb44b 100644 --- a/sam/onyx/generate_matrices.py +++ b/sam/onyx/generate_matrices.py @@ -43,8 +43,8 @@ def __init__(self, name='B', shape=None, sparsity=0.6, format='CSF', dump_dir=No self.dump_dir = tempfile.gettempdir() if tensor is not None: - if not tensor.dtype == numpy.float32: - self.array = tensor + if not self.use_fp: + self.array = tensor.astype(numpy.uint16, casting='unsafe') self.shape = self.array.shape else: self.array = tensor @@ -490,7 +490,7 @@ def create_matrix_from_point_list(name, pt_list, shape, use_fp=False, base=16) - else: mat_base = mat_base.astype(numpy.uint16, casting='unsafe') - mg = MatrixGenerator(name=f"{name}", shape=shape, sparsity=0.7, format='CSF', dump_dir=None, tensor=mat_base) + mg = MatrixGenerator(name=f"{name}", shape=shape, sparsity=0.7, format='CSF', dump_dir=None, tensor=mat_base, use_fp=use_fp) return mg