-
Notifications
You must be signed in to change notification settings - Fork 17
/
pretrain_utils.py
44 lines (37 loc) · 1.03 KB
/
pretrain_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import scanpy as sc
import pandas as pd
from anndata import AnnData
import warnings
from builtins import int
warnings.filterwarnings('ignore')
import losses, miners, distances, reducers, testers
from utils.transformations import *
from utils.accuracy_calculator import AccuracyCalculator
import torch
import torch.optim as optim
from torch import nn
import numpy as np
import utils.logging_presets as logging_presets
import record_keeper
import torch.nn.functional as F
from tqdm import tqdm
from pathlib import Path
import random
#from torch.utils.tensorboard import SummaryWriter
import argparse
from datetime import datetime
import sys
# https://www.microsoft.com/en-us/research/blog/less-pain-more-gain-a-simple-method-for-vae-training-with-less-of-that-kl-vanishing-agony/
def get_kld_cycle(epoch, period=20):
'''
0-10: 0 to 1
10-20 1
21-30 0 to 1
30-40 1
'''
ct = epoch % period
pt = epoch % (period//2)
if ct >= period//2:
return 1
else:
return min(1, (pt) / (period//2))