From 07f183ccd8a47ec91516a6f0200864ea7bd7b7ec Mon Sep 17 00:00:00 2001 From: Antony Chan Date: Sat, 5 Oct 2024 13:36:23 -0700 Subject: [PATCH] Resolve linter error E721 --- proximal/algorithms/absorb.py | 2 +- proximal/algorithms/invert.py | 2 +- proximal/algorithms/linearized_admm.py | 2 +- proximal/algorithms/merge.py | 4 ++-- proximal/algorithms/pock_chambolle.py | 2 +- proximal/algorithms/problem.py | 14 +++++++------- proximal/prox_fns/prox_fn.py | 6 +++--- proximal/utils/cuda_codegen.py | 2 +- pyproject.toml | 3 ++- 9 files changed, 19 insertions(+), 18 deletions(-) diff --git a/proximal/algorithms/absorb.py b/proximal/algorithms/absorb.py index 74e31b5..6f98478 100644 --- a/proximal/algorithms/absorb.py +++ b/proximal/algorithms/absorb.py @@ -40,7 +40,7 @@ def absorb_lin_op(prox_fn): if isinstance(prox_fn.lin_op, Variable): return [prox_fn] # Absorb a lin op into sum_entries/zero. - if type(prox_fn) == zero_prox and prox_fn.gamma == 0: + if isinstance(prox_fn, zero_prox) and prox_fn.gamma == 0: outputs = [] inputs = [prox_fn.c] for arg in prox_fn.lin_op.input_nodes: diff --git a/proximal/algorithms/invert.py b/proximal/algorithms/invert.py index 4422c06..2dceb6c 100644 --- a/proximal/algorithms/invert.py +++ b/proximal/algorithms/invert.py @@ -54,7 +54,7 @@ def get_diag_quads(prox_fns, freq): quad_funcs = [fn for fn in prox_fns if isinstance(fn, sum_squares)] if freq: return [fn for fn in quad_funcs if fn.lin_op.is_diag(freq=True) and - type(fn) == sum_squares] + isinstance(fn, sum_squares)] else: return [fn for fn in quad_funcs if fn.lin_op.is_diag(freq=False)] diff --git a/proximal/algorithms/linearized_admm.py b/proximal/algorithms/linearized_admm.py index e6adab8..95524bc 100644 --- a/proximal/algorithms/linearized_admm.py +++ b/proximal/algorithms/linearized_admm.py @@ -18,7 +18,7 @@ def partition(prox_fns, try_diagonalize=True): omega_fns = [] if len(quad_fns) == 0: for fn in prox_fns: - if type(fn.lin_op) == Variable: + if isinstance(fn.lin_op, Variable): split_fn = [fn] break omega_fns = split_fn diff --git a/proximal/algorithms/merge.py b/proximal/algorithms/merge.py index 4e59707..f25ff4b 100644 --- a/proximal/algorithms/merge.py +++ b/proximal/algorithms/merge.py @@ -30,9 +30,9 @@ def can_merge(lh_prox, rh_prox): """ # Lin ops must be the same. if lh_prox.lin_op == rh_prox.lin_op: - if type(lh_prox) == zero_prox or type(rh_prox) == zero_prox: + if isinstance(lh_prox, zero_prox) or isinstance(rh_prox, zero_prox): return True - elif type(lh_prox) == sum_squares or type(rh_prox) == sum_squares: + if isinstance(lh_prox, sum_squares) or isinstance(rh_prox, sum_squares): return True return False diff --git a/proximal/algorithms/pock_chambolle.py b/proximal/algorithms/pock_chambolle.py index 61f32f2..3266063 100644 --- a/proximal/algorithms/pock_chambolle.py +++ b/proximal/algorithms/pock_chambolle.py @@ -20,7 +20,7 @@ def partition(prox_fns, try_diagonalize=True): omega_fns = [] if len(quad_fns) == 0: for fn in prox_fns: - if type(fn.lin_op) == Variable: + if isinstance(fn.lin_op, Variable): split_fn = [fn] break omega_fns = split_fn diff --git a/proximal/algorithms/problem.py b/proximal/algorithms/problem.py index 494d362..5938290 100644 --- a/proximal/algorithms/problem.py +++ b/proximal/algorithms/problem.py @@ -107,13 +107,13 @@ def solve(self, solver=None, test_adjoints = False, test_norm = False, show_grap # Absorb offsets. prox_fns = [absorb.absorb_offset(fn) for fn in prox_fns] # TODO more analysis of what solver to use. - + if show_graph: print("Computational graph before optimizing:") graph_visualize(prox_fns, filename = show_graph if type(show_graph) is str else None) - + # Short circuit with one function. - if len(prox_fns) == 1 and type(prox_fns[0].lin_op) == Variable: + if len(prox_fns) == 1 and isinstance(prox_fns[0].lin_op, Variable): fn = prox_fns[0] var = fn.lin_op var.value = fn.prox(0, np.zeros(fn.lin_op.shape)) @@ -138,7 +138,7 @@ def solve(self, solver=None, test_adjoints = False, test_norm = False, show_grap L.norm_bound(output_mags) if NotImplemented not in output_mags: assert len(output_mags) == 1 - + x = random(L.input_size) x = x / LA.norm(x) y = np.zeros(L.output_size) @@ -148,7 +148,7 @@ def solve(self, solver=None, test_adjoints = False, test_norm = False, show_grap if ny > output_mags[0]: raise RuntimeError("wrong implementation of norm!") print("%.3f <= ||K|| = %.3f (%.3f)" % (ny, output_mags[0], nL2)) - + # Scale the problem. if self.scale: K = CompGraph(vstack([fn.lin_op for fn in psi_fns]), @@ -170,7 +170,7 @@ def solve(self, solver=None, test_adjoints = False, test_norm = False, show_grap # test adjoints L = CompGraph(vstack([fn.lin_op for fn in psi_fns])) from numpy.random import random - + x = random(L.input_size) yt = np.zeros(L.output_size) #print("x=", x) @@ -193,7 +193,7 @@ def solve(self, solver=None, test_adjoints = False, test_norm = False, show_grap raise RuntimeError("Unmatched adjoints: " + str(r)) else: print("Adjoint test passed.", r) - + if self.implem == Impl['pycuda']: kwargs['adapter'] = PyCudaAdapter() opt_val = module.solve(psi_fns, omega_fns, diff --git a/proximal/prox_fns/prox_fn.py b/proximal/prox_fns/prox_fn.py index 53cbb6c..c2a6b8e 100644 --- a/proximal/prox_fns/prox_fn.py +++ b/proximal/prox_fns/prox_fn.py @@ -216,7 +216,7 @@ def prox_cuda(self, rho, v, *args, **kwargs): if hasattr(self, "_prox_cuda"): if self.kernel_cuda_prox is None: self.gen_cuda_code() - if not type(v) == gpuarray.GPUArray: + if not isinstance(v, gpuarray.GPUArray): v = gpuarray.to_gpu(v.astype(np.float32)) xhat = gpuarray.zeros(v.shape, dtype=np.float32) if "offset" in kwargs: @@ -271,7 +271,7 @@ def __add__(self, other): """ if isinstance(other, ProxFn): return [self, other] - elif type(other) == list: + elif isinstance(other, list): return [self] + other else: return NotImplemented @@ -279,7 +279,7 @@ def __add__(self, other): def __radd__(self, other): """Called for list + ProxFn. """ - if type(other) == list: + if isinstance(other, list): return other + [self] else: return NotImplemented diff --git a/proximal/utils/cuda_codegen.py b/proximal/utils/cuda_codegen.py index cce0b4a..ab398c1 100644 --- a/proximal/utils/cuda_codegen.py +++ b/proximal/utils/cuda_codegen.py @@ -449,7 +449,7 @@ def gen_code(self, fcn, parent = None, shape = None): except AttributeError: buffers = [] for aname, aval in buffers: - if type(aval) == np.ndarray and aval.dtype == np.int32: + if isinstance(aval, np.ndarray) and aval.dtype == np.int32: aval = gpuarray.to_gpu(aval) self.cuda_args.append( (aname, aval, "int") ) else: diff --git a/pyproject.toml b/pyproject.toml index 2df31f8..2b26be2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,10 @@ [tool.ruff] line-length=127 target-version="py310" +exclude=['proximal/halide/subprojects/'] [tool.ruff.lint] -select=["E9", "F63", "F7", "F82", "C90"] +select=["E9", "E713", "E721", "F63", "F7", "F82", "C90"] [tool.ruff.lint.mccabe] max-complexity=30