Skip to content

Commit

Permalink
1
Browse files Browse the repository at this point in the history
  • Loading branch information
Worromots committed May 23, 2022
1 parent 501bb7f commit d965bd8
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 8 deletions.
8 changes: 4 additions & 4 deletions attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


class Patten(nn.Module):
def __int__(self, in_dim, n):
super(Patten, self).__int__()
def __init__(self, in_dim, n):
super(Patten, self).__init__()
self.in_channel = in_dim

self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // n, kernel_size=1)
Expand All @@ -33,8 +33,8 @@ def forward(self, x):


class Catten(nn.Module):
def __int__(self, in_dim):
super(Catten, self).__int__()
def __init__(self, in_dim):
super(Catten, self).__init__()
self.channel_in = in_dim

self.gamma = nn.Parameter(torch.zeros(1))
Expand Down
7 changes: 5 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def init_extra(self):
raise Exception("数据集不存在,名字打错了")

def merge_from_args(self, args):
self.gpu = args.gpu
self.epochs = args.epochs
if args.gpu:
self.gpu = args.gpu
if args.epochs:
self.epochs = args.epochs
if args.train_seg:
self.train_mode = "seg"
elif args.train_dec:
Expand Down Expand Up @@ -106,3 +108,4 @@ def get_as_dict(self):
'input_h': self.input_h,
'input_c': self.input_c,
}
return params
4 changes: 2 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def parse_args():
parser.add_argument('--train_seg', type=str, required=False, help="train part")
parser.add_argument('--train_dec', type=str, required=False, help="train part")
parser.add_argument('--train_total', type=str, required=False, help="train part")

return parser
args = parser.parse_args()
return args

if __name__ == '__main__':
args = parse_args()
Expand Down
7 changes: 7 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def train(self):

self._set_rets_path()
self._create_results_dirs()
self.print_run_params()

device = self._get_device()
model = self._get_model().to(device)
Expand Down Expand Up @@ -333,3 +334,9 @@ def training_iteration(self, data, device, model, criterion_seg, criterion_dec,
optimizer.zero_grad()

return total_loss_seg, total_loss_dec, total_loss, total_correct


def print_run_params(self):
for l in sorted(map(lambda e: e[0] + ":" + str(e[1]) + "\n", self.cfg.get_as_dict().items())):
k, v = l.split(":")
self._log(f"{k:25s} : {str(v.strip())}")

0 comments on commit d965bd8

Please sign in to comment.