diff --git a/project/datasets.py b/project/datasets.py index 49d0bf0e..3a2a5754 100644 --- a/project/datasets.py +++ b/project/datasets.py @@ -22,10 +22,10 @@ def make_pts(N): class Graph: - def __init__(self, vis=False): + def __init__(self, vis=False, vis_args={}): self.gifs = [] if vis: - self.vis = visdom.Visdom() + self.vis = visdom.Visdom(**vis_args) else: self.vis = None self.first = True @@ -74,8 +74,8 @@ def graph(self, outfile, model=None): class Simple(Graph): - def __init__(self, N, vis=False): - super().__init__(vis) + def __init__(self, N, vis=False, vis_args={}): + super().__init__(vis, vis_args) self.N = N self.X = make_pts(N) self.y = [] @@ -85,8 +85,8 @@ def __init__(self, N, vis=False): class Split(Graph): - def __init__(self, N, vis=False): - super().__init__(vis) + def __init__(self, N, vis=False, vis_args={}): + super().__init__(vis, vis_args) self.N = N self.X = make_pts(N) self.y = [] @@ -96,8 +96,8 @@ def __init__(self, N, vis=False): class Xor(Graph): - def __init__(self, N, vis=False): - super().__init__(vis) + def __init__(self, N, vis=False, vis_args={}): + super().__init__(vis, vis_args) self.N = N self.X = make_pts(N) self.y = []