Skip to content

Commit

Permalink
Resolve linter error E721
Browse files Browse the repository at this point in the history
  • Loading branch information
antonysigma committed Oct 5, 2024
1 parent b6d03db commit 07f183c
Show file tree
Hide file tree
Showing 9 changed files with 19 additions and 18 deletions.
2 changes: 1 addition & 1 deletion proximal/algorithms/absorb.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def absorb_lin_op(prox_fn):
if isinstance(prox_fn.lin_op, Variable):
return [prox_fn]
# Absorb a lin op into sum_entries/zero.
if type(prox_fn) == zero_prox and prox_fn.gamma == 0:
if isinstance(prox_fn, zero_prox) and prox_fn.gamma == 0:
outputs = []
inputs = [prox_fn.c]
for arg in prox_fn.lin_op.input_nodes:
Expand Down
2 changes: 1 addition & 1 deletion proximal/algorithms/invert.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_diag_quads(prox_fns, freq):
quad_funcs = [fn for fn in prox_fns if isinstance(fn, sum_squares)]
if freq:
return [fn for fn in quad_funcs if fn.lin_op.is_diag(freq=True) and
type(fn) == sum_squares]
isinstance(fn, sum_squares)]
else:
return [fn for fn in quad_funcs if fn.lin_op.is_diag(freq=False)]

Expand Down
2 changes: 1 addition & 1 deletion proximal/algorithms/linearized_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def partition(prox_fns, try_diagonalize=True):
omega_fns = []
if len(quad_fns) == 0:
for fn in prox_fns:
if type(fn.lin_op) == Variable:
if isinstance(fn.lin_op, Variable):
split_fn = [fn]
break
omega_fns = split_fn
Expand Down
4 changes: 2 additions & 2 deletions proximal/algorithms/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def can_merge(lh_prox, rh_prox):
"""
# Lin ops must be the same.
if lh_prox.lin_op == rh_prox.lin_op:
if type(lh_prox) == zero_prox or type(rh_prox) == zero_prox:
if isinstance(lh_prox, zero_prox) or isinstance(rh_prox, zero_prox):
return True
elif type(lh_prox) == sum_squares or type(rh_prox) == sum_squares:
if isinstance(lh_prox, sum_squares) or isinstance(rh_prox, sum_squares):
return True

return False
Expand Down
2 changes: 1 addition & 1 deletion proximal/algorithms/pock_chambolle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def partition(prox_fns, try_diagonalize=True):
omega_fns = []
if len(quad_fns) == 0:
for fn in prox_fns:
if type(fn.lin_op) == Variable:
if isinstance(fn.lin_op, Variable):
split_fn = [fn]
break
omega_fns = split_fn
Expand Down
14 changes: 7 additions & 7 deletions proximal/algorithms/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ def solve(self, solver=None, test_adjoints = False, test_norm = False, show_grap
# Absorb offsets.
prox_fns = [absorb.absorb_offset(fn) for fn in prox_fns]
# TODO more analysis of what solver to use.

if show_graph:
print("Computational graph before optimizing:")
graph_visualize(prox_fns, filename = show_graph if type(show_graph) is str else None)

# Short circuit with one function.
if len(prox_fns) == 1 and type(prox_fns[0].lin_op) == Variable:
if len(prox_fns) == 1 and isinstance(prox_fns[0].lin_op, Variable):
fn = prox_fns[0]
var = fn.lin_op
var.value = fn.prox(0, np.zeros(fn.lin_op.shape))
Expand All @@ -138,7 +138,7 @@ def solve(self, solver=None, test_adjoints = False, test_norm = False, show_grap
L.norm_bound(output_mags)
if NotImplemented not in output_mags:
assert len(output_mags) == 1

x = random(L.input_size)
x = x / LA.norm(x)
y = np.zeros(L.output_size)
Expand All @@ -148,7 +148,7 @@ def solve(self, solver=None, test_adjoints = False, test_norm = False, show_grap
if ny > output_mags[0]:
raise RuntimeError("wrong implementation of norm!")
print("%.3f <= ||K|| = %.3f (%.3f)" % (ny, output_mags[0], nL2))

# Scale the problem.
if self.scale:
K = CompGraph(vstack([fn.lin_op for fn in psi_fns]),
Expand All @@ -170,7 +170,7 @@ def solve(self, solver=None, test_adjoints = False, test_norm = False, show_grap
# test adjoints
L = CompGraph(vstack([fn.lin_op for fn in psi_fns]))
from numpy.random import random

x = random(L.input_size)
yt = np.zeros(L.output_size)
#print("x=", x)
Expand All @@ -193,7 +193,7 @@ def solve(self, solver=None, test_adjoints = False, test_norm = False, show_grap
raise RuntimeError("Unmatched adjoints: " + str(r))
else:
print("Adjoint test passed.", r)

if self.implem == Impl['pycuda']:
kwargs['adapter'] = PyCudaAdapter()
opt_val = module.solve(psi_fns, omega_fns,
Expand Down
6 changes: 3 additions & 3 deletions proximal/prox_fns/prox_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def prox_cuda(self, rho, v, *args, **kwargs):
if hasattr(self, "_prox_cuda"):
if self.kernel_cuda_prox is None:
self.gen_cuda_code()
if not type(v) == gpuarray.GPUArray:
if not isinstance(v, gpuarray.GPUArray):
v = gpuarray.to_gpu(v.astype(np.float32))
xhat = gpuarray.zeros(v.shape, dtype=np.float32)
if "offset" in kwargs:
Expand Down Expand Up @@ -271,15 +271,15 @@ def __add__(self, other):
"""
if isinstance(other, ProxFn):
return [self, other]
elif type(other) == list:
elif isinstance(other, list):
return [self] + other
else:
return NotImplemented

def __radd__(self, other):
"""Called for list + ProxFn.
"""
if type(other) == list:
if isinstance(other, list):
return other + [self]
else:
return NotImplemented
Expand Down
2 changes: 1 addition & 1 deletion proximal/utils/cuda_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def gen_code(self, fcn, parent = None, shape = None):
except AttributeError:
buffers = []
for aname, aval in buffers:
if type(aval) == np.ndarray and aval.dtype == np.int32:
if isinstance(aval, np.ndarray) and aval.dtype == np.int32:
aval = gpuarray.to_gpu(aval)
self.cuda_args.append( (aname, aval, "int") )
else:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
[tool.ruff]
line-length=127
target-version="py310"
exclude=['proximal/halide/subprojects/']

[tool.ruff.lint]
select=["E9", "F63", "F7", "F82", "C90"]
select=["E9", "E713", "E721", "F63", "F7", "F82", "C90"]

[tool.ruff.lint.mccabe]
max-complexity=30

0 comments on commit 07f183c

Please sign in to comment.