forked from KarchinLab/bigmhc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bigmhc_train
94 lines (68 loc) · 2.31 KB
/
bigmhc_train
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
------------------------------------------------------------------------
Copyright 2023 Benjamin Alexander Albert [Karchin Lab]
All Rights Reserved
BigMHC Academic License
train.py
------------------------------------------------------------------------
"""
import os
import torch
import bigmhc.src.cli
from bigmhc.src.bigmhc import BigMHC
def train(model, data, args):
    """Run the BigMHC training loop for ``args.epochs`` epochs.

    Parameters
    ----------
    model : BigMHC
        The model to train. It is wrapped via ``BigMHC.accelerate`` for the
        requested devices (presumably DataParallel and/or device placement —
        the DataParallel isinstance check below supports that assumption).
    data : iterable
        Batch iterator; ``data.dataset.makebats`` is called each epoch to
        rebuild (and shuffle) the batches. Each batch exposes ``mhc``,
        ``pep``, and ``tgt`` tensors.
    args : argparse.Namespace
        Expects: ``verbose``, ``devices``, ``transferlearn``, ``lr``,
        ``epochs``, ``maxbat``, ``out``.

    Side effects
    ------------
    Mutates model weights in place; if ``args.out`` is set, saves a
    checkpoint per epoch under ``<out>/ep<N>``.
    """
    if args.verbose:
        print("starting training on devices: {}".format(args.devices))
    # Wrap/move the model onto the requested devices and switch to train mode.
    model = BigMHC.accelerate(
        model=model,
        devices=args.devices).train()
    # Place the loss on whatever device the (first) model parameters live on.
    dev = next(model.parameters()).device
    lossf = torch.nn.BCEWithLogitsLoss().to(dev)
    if args.transferlearn:
        # accelerate may have wrapped the model in DataParallel; unwrap to
        # reach the named parameters of the underlying module.
        if isinstance(model, torch.nn.DataParallel):
            module = model.module
        else:
            module = model
        # Transfer learning: optimize only the designated transfer layers
        # (names enumerated by BigMHC.tllayers()); all else stays frozen
        # from the optimizer's perspective.
        optmz = torch.optim.AdamW(
            params=[v for k,v in module.named_parameters()
                    if k in BigMHC.tllayers()],
            lr=args.lr)
    else:
        # Full training: optimize every parameter.
        optmz = torch.optim.AdamW(
            params=model.parameters(),
            lr=args.lr)
    for ep in range(args.epochs):
        eperr = 0
        # Rebuild shuffled batches each epoch. negfrac=0.99 for base
        # training (presumably caps the negative-example fraction per
        # batch — confirm against makebats); disabled for transfer learning.
        data.dataset.makebats(
            maxbat=args.maxbat,
            shuffle=True,
            negfrac=None if args.transferlearn else 0.99)
        for bat in data:
            tgt = bat.tgt.float().to(dev)
            # During transfer learning, skip batches with no positive
            # targets (tgt.sum() == 0) — BCE on all-negative batches is
            # not useful for the transfer objective here.
            if args.transferlearn and not tgt.sum():
                continue
            optmz.zero_grad()
            out,_ = model(
                mhc=bat.mhc.float().to(dev),
                pep=bat.pep.float().to(dev))
            err = lossf(out, tgt)
            err.backward()
            # Accumulate the mean per-batch loss for epoch reporting.
            # NOTE(review): len(data) is the batch count; skipped batches
            # still count in the denominator during transfer learning.
            eperr += float(err) / len(data)
            optmz.step()
        if args.verbose:
            print("ep {} loss: {}".format(ep+1, eperr))
        if args.out:
            # Checkpoint per epoch: unwrap (decelerate) before saving so
            # the serialized weights are not DataParallel-prefixed, then
            # re-accelerate to continue training.
            epdir = os.path.join(args.out, "ep{}".format(ep+1))
            model = BigMHC.decelerate(model)
            model.save(epdir, tl=args.transferlearn)
            model = BigMHC.accelerate(model, args.devices)
def main():
    """Parse CLI arguments and launch training for a single model.

    Raises:
        ValueError: if the CLI yields more than one model to train.
    """
    parsed_args, loader, models = bigmhc.src.cli.parseArgs(train=True)
    if len(models) > 1:
        raise ValueError(
            "training multiple models is currently unsupported")
    train(model=models[0], data=loader, args=parsed_args)
# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()