forked from KarchinLab/bigmhc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bigmhc_predict
83 lines (61 loc) · 2.2 KB
/
bigmhc_predict
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
------------------------------------------------------------------------
Copyright 2023 Benjamin Alexander Albert [Karchin Lab]
All Rights Reserved
BigMHC Academic License
predict.py
------------------------------------------------------------------------
"""
import os
import torch
import pandas as pd
import bigmhc.src.cli
from bigmhc.src.bigmhc import BigMHC
def predict(models, data, args):
    """Run an ensemble of BigMHC models over every batch and collect results.

    Each model is put in eval mode and replaced (in the caller's ``models``
    list) by the result of ``BigMHC.accelerate``. For every batch, each
    model's raw output is passed through a sigmoid and the per-model
    probabilities are averaged. When ``args.saveatt`` is set, the per-model
    attention tensors are averaged as well and appended as ``att_<i>``
    columns.

    Parameters
    ----------
    models : list
        Models to ensemble. Entries are replaced in place by their
        accelerated versions.
    data : iterable of batches
        Iterated with ``enumerate`` and sized with ``len``. Each batch must
        expose ``mhc`` and ``pep`` tensors, and
        ``data.dataset.getbat(idx=..., enc=False)`` must return the matching
        raw rows as a pandas DataFrame. (Presumably a
        ``torch.utils.data.DataLoader`` -- confirm against the CLI.)
    args : namespace
        Reads ``verbose``, ``devices``, ``modelname`` (name of the
        prediction column) and ``saveatt`` (whether to add attention
        columns).

    Returns
    -------
    pandas.DataFrame
        All raw batch rows with the ensemble prediction column (and optional
        attention columns), concatenated and sorted by index.

    Raises
    ------
    ValueError
        If ``data`` yields no batches.
    """
    if args.verbose:
        print("starting prediction devices: {}".format(args.devices))

    # Inference mode, then move/wrap each model for the requested devices.
    # The caller's list is updated in place, as before.
    for i, model in enumerate(models):
        model.eval()
        models[i] = BigMHC.accelerate(model, devices=args.devices)

    preds = []
    with torch.no_grad():
        for idx, bat in enumerate(data):
            if args.verbose:
                print("batch {}/{}".format(idx + 1, len(data)))

            probs = []
            atts = []
            for model in models:
                # Send inputs to whichever device holds this model's weights.
                dev = next(model.parameters()).device
                _out, _att = model(
                    mhc=bat.mhc.to(dev),
                    pep=bat.pep.to(dev))
                probs.append(torch.sigmoid(_out))
                atts.append(_att)

            # Ensemble by averaging the per-model probabilities.
            out = torch.mean(torch.stack(probs), dim=0)

            rawbat = data.dataset.getbat(idx=idx, enc=False)
            rawbat[args.modelname] = out.cpu().numpy()

            if args.saveatt:
                # Average attention only when it is saved, and copy it to the
                # host in a single transfer instead of one per column.
                att = torch.mean(torch.stack(atts), dim=0).cpu().numpy()
                # assumes att is 2-D (batch, positions), as the original
                # per-column att[:, x] indexing implied -- TODO confirm
                attdf = pd.DataFrame(
                    att,
                    index=rawbat.index,
                    columns=["att_{}".format(x) for x in range(att.shape[1])])
                rawbat = pd.concat((rawbat, attdf), axis=1)

            preds.append(rawbat)

    if not preds:
        # pd.concat([]) would raise an opaque "No objects to concatenate".
        raise ValueError("no batches to predict: data is empty")

    return pd.concat(preds).sort_index()
def main():
    """Command-line entry point: run BigMHC prediction and write a CSV.

    Arguments, data and models come from
    ``bigmhc.src.cli.parseArgs(train=False)``. The table returned by
    :func:`predict` is written to ``args.out`` without the DataFrame index.
    If ``args.out`` is empty/unset, it is set to ``args.input + ".prd"`` on
    the namespace first. If ``args.tgtcol`` is given, the ``"tgt"`` column is
    cast to ``int`` before writing.
    """
    args, data, models = bigmhc.src.cli.parseArgs(train=False)

    table = predict(models=models, data=data, args=args)

    # No explicit output path: derive one from the input path.
    if not args.out:
        args.out = "".join((args.input, ".prd")) if False else args.input + ".prd"

    if args.verbose:
        print("writing predictions to {}".format(args.out))

    # Targets are stored as integer labels in the written file.
    if args.tgtcol is not None:
        target = table["tgt"]
        table["tgt"] = target.astype(int)

    table.to_csv(args.out, index=False)


if __name__ == "__main__":
    main()