Skip to content

Commit

Permalink
more type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
mr-c committed Jun 26, 2024
1 parent 9022dcc commit 14ec092
Show file tree
Hide file tree
Showing 14 changed files with 259 additions and 178 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
MODULE1=wes_client
MODULE2=wes_service
PACKAGE=wes-service
EXTRAS=[toil,arvados]
EXTRAS=[toil]

# `SHELL=bash` doesn't work for some, so don't use BASH-isms like
# `[[` conditional expressions.
Expand Down
36 changes: 20 additions & 16 deletions cwl_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,32 @@
import threading
import time
import copy
import shutil
from typing import List, Dict, Any, Generator, Tuple

import werkzeug.wrappers.response

app = Flask(__name__)

jobs_lock = threading.Lock()
jobs = []
jobs: List["Job"] = []


class Job(threading.Thread):
def __init__(self, jobid, path, inputobj):
def __init__(self, jobid: int, path: str, inputobj: bytes) -> None:
super().__init__()
self.jobid = jobid
self.path = path
self.inputobj = inputobj
self.updatelock = threading.Lock()
self.begin()

def begin(self):
def begin(self) -> None:
loghandle, self.logname = tempfile.mkstemp()
with self.updatelock:
self.outdir = tempfile.mkdtemp()
self.proc = subprocess.Popen(
["cwl-runner", self.path, "-"],
[shutil.which("cwl-runner") or "cwl-runner", self.path, "-"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=loghandle,
Expand All @@ -44,7 +48,7 @@ def begin(self):
"output": None,
}

def run(self):
def run(self) -> None:
self.stdoutdata, self.stderrdata = self.proc.communicate(self.inputobj)
if self.proc.returncode == 0:
outobj = yaml.load(self.stdoutdata, Loader=yaml.FullLoader)
Expand All @@ -55,31 +59,31 @@ def run(self):
with self.updatelock:
self.status["state"] = "Failed"

def getstatus(self):
def getstatus(self) -> Dict[str, Any]:
with self.updatelock:
return self.status.copy()

def cancel(self):
def cancel(self) -> None:
if self.status["state"] == "Running":
self.proc.send_signal(signal.SIGQUIT)
with self.updatelock:
self.status["state"] = "Canceled"

def pause(self):
def pause(self) -> None:
if self.status["state"] == "Running":
self.proc.send_signal(signal.SIGTSTP)
with self.updatelock:
self.status["state"] = "Paused"

def resume(self):
def resume(self) -> None:
if self.status["state"] == "Paused":
self.proc.send_signal(signal.SIGCONT)
with self.updatelock:
self.status["state"] = "Running"


@app.route("/run", methods=["POST"])
def runworkflow():
def runworkflow() -> werkzeug.wrappers.response.Response:
path = request.args["wf"]
with jobs_lock:
jobid = len(jobs)
Expand All @@ -90,7 +94,7 @@ def runworkflow():


@app.route("/jobs/<int:jobid>", methods=["GET", "POST"])
def jobcontrol(jobid):
def jobcontrol(jobid: int) -> Tuple[str, int]:
with jobs_lock:
job = jobs[jobid]
if request.method == "POST":
Expand All @@ -104,10 +108,10 @@ def jobcontrol(jobid):
job.resume()

status = job.getstatus()
return json.dumps(status, indent=4), 200, ""
return json.dumps(status, indent=4), 200


def logspooler(job):
def logspooler(job: Job) -> Generator[str, None, None]:
with open(job.logname) as f:
while True:
r = f.read(4096)
Expand All @@ -121,18 +125,18 @@ def logspooler(job):


@app.route("/jobs/<int:jobid>/log", methods=["GET"])
def getlog(jobid):
def getlog(jobid: int) -> Response:
with jobs_lock:
job = jobs[jobid]
return Response(logspooler(job))


@app.route("/jobs", methods=["GET"])
def getjobs():
def getjobs() -> Response:
with jobs_lock:
jobscopy = copy.copy(jobs)

def spool(jc):
def spool(jc: List[Job]) -> Generator[str, None, None]:
yield "["
first = True
for j in jc:
Expand Down
19 changes: 9 additions & 10 deletions cwltool_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,33 @@
import cwltool.main
import tempfile
import logging
import StringIO
from io import StringIO
import json
from typing import List, Union

_logger = logging.getLogger("cwltool")
_logger.setLevel(logging.ERROR)


def main(args=None):
if args is None:
args = sys.argv[1:]

def main(args: List[str] = sys.argv[1:]) -> int:
if len(args) == 0:
print("Workflow must be on command line")
return 1

parser = cwltool.main.arg_parser()
parser = cwltool.argparser.arg_parser()
parsedargs = parser.parse_args(args)

a = True
a: Union[bool, str] = True
while a:
a = True
msg = ""
while a and a != "\n":
a = sys.stdin.readline()
msg += a

outdir = tempfile.mkdtemp("", parsedargs.tmp_outdir_prefix)

t = StringIO.StringIO(msg)
err = StringIO.StringIO()
t = StringIO(msg)
err = StringIO()
if (
cwltool.main.main(
["--outdir=" + outdir] + args + ["-"], stdin=t, stderr=err
Expand All @@ -43,6 +40,8 @@ def main(args=None):
sys.stdout.write(json.dumps({"cwl:error": err.getvalue()}))
sys.stdout.write("\n\n")
sys.stdout.flush()
a = True
return 0


if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions mypy-requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
mypy==1.10.1
types-PyYAML
types-requests
types-setuptools
8 changes: 8 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[mypy]
show_error_context = true
show_column_numbers = true
show_error_codes = true
pretty = true
strict = true
[mypy-ruamel.*]
ignore_errors = True
24 changes: 13 additions & 11 deletions test/test_client_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class IntegrationTest(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
dirname, filename = os.path.split(os.path.abspath(__file__))
self.testdata_dir = dirname + "data"
self.local = {
Expand All @@ -37,26 +37,28 @@ def setUp(self):
"pyWithPrefix": ("3", "PY"),
}

def tearDown(self):
def tearDown(self) -> None:
unittest.TestCase.tearDown(self)

def test_expand_globs(self):
def test_expand_globs(self) -> None:
"""Asserts that wes_client.expand_globs() sees the same files in the cwd as 'ls'."""
files = subprocess.check_output(["ls", "-1", "."])

# python 2/3 bytestring/utf-8 compatibility
if isinstance(files, str):
files = files.split("\n")
files2 = files.split("\n")
else:
files = files.decode("utf-8").split("\n")
files2 = files.decode("utf-8").split("\n")

if "" in files:
files.remove("")
files = ["file://" + os.path.abspath(f) for f in files]
if "" in files2:
files2.remove("")
files2 = ["file://" + os.path.abspath(f) for f in files2]
glob_files = expand_globs("*")
assert set(files) == glob_files, "\n" + str(set(files)) + "\n" + str(glob_files)
assert set(files2) == glob_files, (
"\n" + str(set(files2)) + "\n" + str(glob_files)
)

def testSupportedFormatChecking(self):
def testSupportedFormatChecking(self) -> None:
"""
Check that non-wdl, -python, -cwl files are rejected.
Expand All @@ -75,7 +77,7 @@ def testSupportedFormatChecking(self):
with self.assertRaises(TypeError):
wf_info(location)

def testFileLocationChecking(self):
def testFileLocationChecking(self) -> None:
"""
Check that the function rejects unsupported file locations.
Expand Down
Loading

0 comments on commit 14ec092

Please sign in to comment.