From 85b4608aac8fbe5fdde53525d9e752e6e8f199a4 Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Thu, 19 Sep 2019 23:29:51 -0400 Subject: [PATCH 1/9] Removing lua torch support --- README.md | 43 ++++++++++------------------------------- py/visdom/__init__.py | 24 +++++++++++------------ py/visdom/server.py | 45 ++++++------------------------------------- th/init.lua | 4 ++++ 4 files changed, 32 insertions(+), 84 deletions(-) diff --git a/README.md b/README.md index 470be9b9..ecaeb967 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ ![visdom_big](https://lh3.googleusercontent.com/-bqH9UXCw-BE/WL2UsdrrbAI/AAAAAAAAnYc/emrxwCmnrW4_CLTyyUttB0SYRJ-i4CCiQCLcB/s0/Screen+Shot+2017-03-06+at+10.51.02+AM.png"visdom_big") -A flexible tool for creating, organizing, and sharing visualizations of live, rich data. Supports Torch and Numpy. +A flexible tool for creating, organizing, and sharing visualizations of live, rich data. Supports Python. * [Overview](#overview) * [Concepts](#concepts) @@ -140,17 +140,13 @@ Using the view dropdown it is possible to select previously saved views, restori ## Setup -Requires Python 2.7/3 (and optionally Torch7) +Requires Python 2.7/3 ```bash # Install Python server and client from pip # (STABLE VERSION, NOT ALL CURRENT FEATURES ARE SUPPORTED) pip install visdom -# Install Torch client -# (STABLE VERSION, NOT ALL CURRENT FEATURES ARE SUPPORTED) -luarocks install visdom - ``` ```bash @@ -159,9 +155,6 @@ pip install -e . # If the above runs into issues, you can try the below easy_install . -# Install Torch client from source (from th directory) -luarocks make - ``` ## Usage @@ -174,7 +167,7 @@ Start the server (probably in a `screen` or `tmux`) from the command line: Visdom now can be accessed by going to `http://localhost:8097` in your browser, or your own host address if specified. -> The `visdom` command is equivalent to running `python -m visdom.server`. +> The `visdom` command is equivalent to running `python -m visdom.server`. >If the above does not work, try using an SSH tunnel to your server by adding the following line to your local `~/.ssh/config`: ```LocalForward 127.0.0.1:8097 127.0.0.1:8097```. @@ -200,29 +193,10 @@ vis.text('Hello, world!') vis.image(np.ones((3, 10, 10))) ``` -#### Torch example -```lua -require 'image' -vis = require 'visdom'() -vis:text{text = 'Hello, world!'} -vis:image{img = image.fabio()} -``` - -Some users have reported issues when connecting Lua clients to the Visdom server. -A potential work-around may be to switch off IPv6: -``` -vis = require 'visdom'() -vis.ipv6 = false -- switches off IPv6 -vis:text{text = 'Hello, world!'} -``` - - ### Demos ```bash python example/demo.py -th example/demo1.lua -th example/demo2.lua ``` @@ -240,7 +214,7 @@ The python visdom client takes a few options: - `http_proxy_host`: host to proxy your incoming socket through (default: `None`) - `http_proxy_port`: port to proxy your incoming socket through (default: `None`) -Other options are either currently unused (endpoint, ipv6) or used for internal functionality (send allows the visdom server to replicate events for the lua client). +Other options are either currently unused (endpoint, ipv6) or used for internal functionality. ### Basics Visdom offers the following basic visualization functions: @@ -394,12 +368,12 @@ packages installed to use this option. #### vis.plotlyplot -This function draws a Plotly `Figure` object. It does not explicitly take options as it assumes you have already explicitly configured the figure's `layout`. +This function draws a Plotly `Figure` object. It does not explicitly take options as it assumes you have already explicitly configured the figure's `layout`. -> **Note** You must have the `plotly` Python package installed to use this function. It can typically be installed by running `pip install plotly`. +> **Note** You must have the `plotly` Python package installed to use this function. It can typically be installed by running `pip install plotly`. #### vis.save -This function saves the `envs` that are alive on the visdom server. It takes input a list (in python) or table (in lua) of env ids to be saved. +This function saves the `envs` that are alive on the visdom server. It takes input a list of env ids to be saved. ### Plotting Further details on the wrapped plotting functions are given below. @@ -638,6 +612,9 @@ Arguments: - [ ] Filtering through windows with regex by title (or meta field) - [ ] Compiling react by python server at runtime +## Note on Lua Torch Support +Support for Lua Torch was deprecated following `v0.1.8.4`. If you'd like to use torch support, you'll need to download that release. You can follow the usage instructions there, but it is no longer officially supported. + ## Contributing See guidelines for contributing [here.](./CONTRIBUTING.md) diff --git a/py/visdom/__init__.py b/py/visdom/__init__.py index 7304503c..a320ccb3 100644 --- a/py/visdom/__init__.py +++ b/py/visdom/__init__.py @@ -225,11 +225,11 @@ def _assert_opts(opts): if opts.get('columnnames'): assert isinstance(opts.get('columnnames'), list), \ - 'columnnames should be a table with column names' + 'columnnames should be a list with column names' if opts.get('rownames'): assert isinstance(opts.get('rownames'), list), \ - 'rownames should be a table with row names' + 'rownames should be a list with row names' if opts.get('jpgquality'): assert isnum(opts.get('jpgquality')), \ @@ -462,7 +462,7 @@ def _send(self, msg, endpoint='events', quiet=False, from_log=False): def save(self, envs): """ This function allows the user to save envs that are alive on the - Tornado server. The envs can be specified as a table (list) of env ids. + Tornado server. The envs can be specified as a list of env ids. """ assert isinstance(envs, list), 'envs should be a list' if len(envs) > 0: @@ -686,7 +686,7 @@ def plotlyplot(self, figure, win=None, env=None): """ This function draws a Plotly 'Figure' object. It does not explicitly take options as it assumes you have already explicitly configured the figure's layout. - Note: You must have the 'plotly' Python package installed to use this function. + Note: You must have the 'plotly' Python package installed to use this function. """ try: import plotly @@ -937,7 +937,7 @@ def scatter(self, X, Y=None, win=None, env=None, opts=None, update=None, - `opts.markersize` : marker size (`number`; default = `'10'`) - `opts.markercolor` : marker color (`np.array`; default = `None`) - `opts.textlabels` : text label for each point (`list`: default = `None`) - - `opts.legend` : `table` containing legend names + - `opts.legend` : `list` containing legend names """ if update == 'remove': assert win is not None @@ -1094,7 +1094,7 @@ def line(self, Y, X=None, win=None, env=None, opts=None, update=None, - `opts.markers` : show markers (`boolean`; default = `false`) - `opts.markersymbol`: marker symbol (`string`; default = `'dot'`) - `opts.markersize` : marker size (`number`; default = `'10'`) - - `opts.legend` : `table` containing legend names + - `opts.legend` : `list` containing legend names If `update` is specified, the figure will be updated without creating a new plot -- this can be used for efficient updating. @@ -1154,8 +1154,8 @@ def heatmap(self, X, win=None, env=None, opts=None): - `opts.colormap`: colormap (`string`; default = `'Viridis'`) - `opts.xmin` : clip minimum value (`number`; default = `X:min()`) - `opts.xmax` : clip maximum value (`number`; default = `X:max()`) - - `opts.columnnames`: `table` containing x-axis labels - - `opts.rownames`: `table` containing y-axis labels + - `opts.columnnames`: `list` containing x-axis labels + - `opts.rownames`: `list` containing y-axis labels """ assert X.ndim == 2, 'data should be two-dimensional' @@ -1204,9 +1204,9 @@ def bar(self, X, Y=None, win=None, env=None, opts=None): The following plot-specific `opts` are currently supported: - - `opts.rownames`: `table` containing x-axis labels + - `opts.rownames`: `list` containing x-axis labels - `opts.stacked` : stack multiple columns in `X` - - `opts.legend` : `table` containing legend labels + - `opts.legend` : `list` containing legend labels """ X = np.squeeze(X) assert X.ndim == 1 or X.ndim == 2, 'X should be one or two-dimensional' @@ -1502,7 +1502,7 @@ def stem(self, X, Y=None, win=None, env=None, opts=None): The following `opts` are supported: - `opts.colormap`: colormap (`string`; default = `'Viridis'`) - - `opts.legend` : `table` containing legend names + - `opts.legend` : `list` containing legend names """ X = np.squeeze(X) @@ -1545,7 +1545,7 @@ def pie(self, X, win=None, env=None, opts=None): The following `opts` are supported: - - `opts.legend`: `table` containing legend names + - `opts.legend`: `list` containing legend names """ X = np.squeeze(X) diff --git a/py/visdom/server.py b/py/visdom/server.py index 62a45a03..44269271 100644 --- a/py/visdom/server.py +++ b/py/visdom/server.py @@ -24,7 +24,6 @@ import traceback from os.path import expanduser -import visdom from zmq.eventloop import ioloop ioloop.install() # Needs to happen before any tornado imports! @@ -425,18 +424,6 @@ def register_window(self, p, eid): self.write(p['id']) -def unpack_lua(req_args): - if req_args['is_table']: - if isinstance(req_args['val'], dict): - return {k: unpack_lua(v) for (k, v) in req_args['val'].items()} - else: - return [unpack_lua(v) for v in req_args['val']] - elif req_args['is_tensor']: - return visdom.from_t7(req_args['val'], b64=True) - else: - return req_args['val'] - - class PostHandler(BaseHandler): def initialize(self, app): self.state = app.state @@ -444,9 +431,6 @@ def initialize(self, app): self.sources = app.sources self.port = app.port self.env_path = app.env_path - self.vis = visdom.Visdom( - port=self.port, send=False, use_incoming_socket=False - ) self.handlers = { 'update': UpdateHandler, 'save': SaveHandler, @@ -455,35 +439,18 @@ def initialize(self, app): 'delete_env': DeleteEnvHandler, } - def func(self, req): - args, kwargs = req['args'], req.get('kwargs', {}) - - args = (unpack_lua(a) for a in args) - - for k in kwargs: - v = kwargs[k] - kwargs[k] = unpack_lua(v) - - func = getattr(self.vis, req['func']) - - return func(*args, **kwargs) - def post(self): req = tornado.escape.json_decode( tornado.escape.to_basestring(self.request.body) ) if req.get('func') is not None: - try: - req, endpoint = self.func(req) - if (endpoint != 'events'): - # Process the request using the proper handler - self.handlers[endpoint].wrap_func(self, req) - return - except Exception: - # get traceback and send it back - print(traceback.format_exc()) - return self.write(traceback.format_exc()) + raise Exception( + 'Support for Lua Torch was deprecated following `v0.1.8.4`. ' + "If you'd like to use torch support, you'll need to download " + "that release. You can follow the usage instructions there, " + "but it is no longer officially supported." + ) eid = extract_eid(req) p = window(req) diff --git a/th/init.lua b/th/init.lua index 60f4119f..68a2ab87 100644 --- a/th/init.lua +++ b/th/init.lua @@ -29,6 +29,10 @@ M.__init = argcheck{ visualization server that wraps plot.ly to show scalable, high-quality visualizations in the browser. + Note: The lua Torch client for visdom was deprecated after visdom + v0.1.8.4, so if you'd like to use visdom for torch, you'll have to + download that specific tag of visdom from the github. + The server can be started with the `server.py` script. The server defaults to port 8097. When the server is running on `domain.com:8097`, then visit that web address in your browser to see the visualization desktop. From b4f243a329355d35db01aa052552cba29925177d Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Fri, 20 Sep 2019 00:45:57 -0400 Subject: [PATCH 2/9] Removing demo, small merge fix' --- example/demo1.lua | 333 -------------------------------------------- example/demo2.lua | 196 -------------------------- py/visdom/server.py | 1 + 3 files changed, 1 insertion(+), 529 deletions(-) delete mode 100644 example/demo1.lua delete mode 100644 example/demo2.lua diff --git a/example/demo1.lua b/example/demo1.lua deleted file mode 100644 index c9a67b62..00000000 --- a/example/demo1.lua +++ /dev/null @@ -1,333 +0,0 @@ ---[[ - -Copyright 2017-present, Facebook, Inc. -All rights reserved. - -This source code is licensed under the license found in the -LICENSE file in the root directory of this source tree. - -]]-- - --- dependencies: -require 'torch' -require 'image' -local paths = require 'paths' - --- intialize visdom Torch client: -local visdom = require 'visdom' -local plot = visdom{server = 'http://localhost', port = 8097} -if not plot:check_connection() then - error('Could not connect, please ensure the visdom server is running') -end - --- text box demo: -local textwindow = plot:text{ - text = 'Hello, world! If I\'m still open, close failed' -} -local updatetextwindow = plot:text{ - text = 'Hello, world! If I don\'t have another line, update text failed.' -} -plot:text{text = 'Here\'s another line', win = updatetextwindow, append = true} - -plot:py_func{func='text', args={'Hello, world!'}} - --- image demo: -plot:image{ - img = image.fabio(), - opts = { - title = 'Fabio', - caption = 'Hello, I am Fabio ;)', - } -} - --- images demo: -plot:images{ - table = {torch.zeros(3, 200, 200) + 0.1, torch.zeros(3, 200, 200) + 0.2}, - opts = { - caption = 'I was a table of tensors...', - } -} - --- images demo: -plot:images{ - tensor = torch.randn(6, 3, 200, 200), - opts = { - caption = 'I was a 4D tensor...', - } -} - --- scatter plot demos: -plot:scatter{ - X = torch.randn(100, 2), - Y = torch.randn(100):gt(0):add(1):long(), - opts = { - legend = {'Apples', 'Pears'}, - xtickmin = -5, - xtickmax = 5, - xtickstep = .5, - ytickmin = -5, - ytickmax = 5, - ytickstep = .5, - markersymbol = 'cross-thin-open', - } -} -plot:scatter{ - X = torch.randn(100, 3), - Y = torch.randn(100):gt(0):add(1):long(), - opts = { - markersize = 5, - legend = {'Men', 'Women'}, - }, -} - --- 2D scatterplot with custom intensities (red channel): -local id = plot:scatter{ - X = torch.randn(255, 2), - opts = { - markersize = 10, - markercolor = torch.zeros(255):random(0, 255), - }, -} - --- check if win_exists works -local exists = plot:win_exists{ - win = id, -} -if not exists then error("created window doesn't exist") end - - -plot:line{ -- add new trace to scatter plot - X = torch.randn(255), - Y = torch.randn(255), - win = id, - name = 'new trace', - update = 'append', -} - --- 2D scatter plot with custom colors: -plot:scatter{ - X = torch.randn(255, 2), - opts = { - markersize = 10, - markercolor = torch.zeros(255, 3):random(0, 255), - }, -} - --- 2D scatter plot with custom colors per label: -plot:scatter{ - X = torch.randn(255, 2), - Y = torch.randn(255):gt(0):add(1):long(), -- two labels - opts = { - markersize = 10, - markercolor = torch.zeros(2, 3):random(0, 255), - }, -} - --- bar plot demos: -plot:bar{ - X = torch.randn(20) -} -plot:bar{ - X = torch.randn(5, 3):abs(), - opts = { - stacked = true, - legend = {'Facebook', 'Google', 'Twitter'}, - rownames = {'2012', '2013', '2014', '2015', '2016'}, - }, -} -plot:bar{ - X = torch.randn(20, 3), - opts = { - stacked = false, - legend = {'The Netherlands', 'France', 'United States'}, - }, -} - --- histogram demo: -plot:histogram{ - X = torch.randn(10000), - opts = {numbins = 20}, -} - --- heatmap demo: -local X = torch.cmul(torch.range(1, 10):reshape(1, 10):expand(5, 10), - torch.range(1, 5):reshape(5, 1):expand(5, 10)) -plot:heatmap{ - X = X, - opts = { - columnnames = {'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'}, - rownames = {'y1', 'y2', 'y3', 'y4', 'y5'}, - colormap = 'Electric', - }, -} - --- contour plot demo: -local x = torch.range(1, 100):reshape(1, 100):expand(100, 100) -local y = torch.range(1, 100):reshape(100, 1):expand(100, 100) -local X = torch.add(x, -50):pow(2):add( - torch.add(y, -50):pow(2) -):div(-math.pow(20, 2)):exp() -plot:contour{X = X, opts = {colormap = 'Viridis'}} - --- surface plot demo: -plot:surf{X = X, opts = {colormap = 'Hot'}} - --- line plot demos: -local Y = torch.range(-4, 4, 0.05) -plot:line{ - Y = torch.cat(torch.cmul(Y, Y), torch.sqrt(Y + 5), 2), - X = torch.cat(Y, Y, 2), - opts = {markers = false} -} - -local id = plot:line{ - Y = torch.cat(torch.range(0, 10), torch.range(0, 10) + 5, 2), - X = torch.cat(torch.range(0, 10), torch.range(0, 10), 2), - opts = {markers = false} -} - -plot:py_func{ - func='line', - args = {torch.randn(10)}, - kwargs = {opts = {title = 'This is lua through python'}} -} - --- update trace demos: -plot:line{ - X = torch.cat(torch.range(11, 20), torch.range(11, 20), 2), - Y = torch.cat(torch.range(11, 20), torch.range(5, 14) * 2 + 5, 2), - win = id, - update = 'append', -} - -plot:line{ - X = torch.range(1, 10), - Y = torch.range(1, 10), - win = id, - name = '3', - update = 'append', -} - -plot:line{ - X = torch.range(1, 10), - Y = torch.range(11, 20), - win = id, - name = '4', - update = 'append', -} - --- stacked line plot demo: -local Y = torch.range(0, 4, 0.02) -plot:line{ - Y = torch.cat(torch.sqrt(Y), torch.sqrt(Y):add(2), 2), - X = torch.cat(Y, Y, 2), - opts = { - fillarea = true, - legend = false, - width = 400, - height = 400, - xlabel = 'Time', - ylabel = 'Volume', - ytype = 'log', - title = 'Stacked area plot', - marginleft = 30, - marginright = 30, - marginbottom = 80, - margintop = 30, - }, -} - --- boxplot demo: -local X = torch.randn(100, 2) -X:narrow(2, 2, 1):add(2) -plot:boxplot{ - X = X, - opts = { - legend = {'Men', 'Women'}, - }, -} - --- stem plot demo: -local Y = torch.range(0, 2 * math.pi, (2 * math.pi) / 70) -local X = torch.cat(torch.sin(Y), torch.cos(Y), 2) -plot:stem{ - X = X, - Y = Y, - opts = { - legend = {'Sine', 'Cosine'}, - }, -} - --- quiver demo: -local X = torch.range(0, 2, 0.2) -local Y = torch.range(0, 2, 0.2) -X = X:resize(1, X:nElement()):expand(X:nElement(), X:nElement()) -Y = Y:resize(Y:nElement(), 1):expand(Y:nElement(), Y:nElement()) -local U = torch.cos(X):cmul(Y) -local V = torch.sin(X):cmul(Y) -plot:quiver{ - X = U, - Y = V, - opts = {normalize = 0.9}, -} - --- pie chart demo: -local X = torch.DoubleTensor{19, 26, 55} -local legend = {'Residential', 'Non-Residential', 'Utility'} -plot:pie{ - X = X, - opts = {legend = legend}, -} - --- svg rendering demo: -local svgstr = [[ - - - Sorry, your browser does not support inline SVG. - -]] -plot:svg{ - svgstr = svgstr, - opts = { - title = 'Example of SVG Rendering', - }, -} - --- mesh plot demo: -local X = torch.DoubleTensor{ - {0, 0, 1, 1, 0, 0, 1, 1}, - {0, 1, 1, 0, 0, 1, 1, 0}, - {0, 0, 0, 0, 1, 1, 1, 1}, -}:t() -local Y = torch.DoubleTensor{ - {7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2}, - {3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3}, - {0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6}, -}:t() -plot:mesh{X = X, Y = Y, opts = {opacity = 0.5}} - --- video demo: -local video = torch.ByteTensor(256, 3, 128, 128) -for n = 1,video:size(1) do - video[n]:fill(n - 1) -end -local ok = pcall(plot.video, plot.video, {tensor = video}) -if not ok then print('Skipped video example') end - --- video demo: -local videofile = '/home/' .. os.getenv('USER') .. '/trailer.ogv' - -- NOTE: Download video from http://media.w3.org/2010/05/sintel/trailer.ogv -if paths.filep(videofile) then - local ok = pcall(plot.video, plot.video, {videofile = videofile}) - if not ok then print('Skipped video example') end -end - --- close text window: -plot:close{win = textwindow} - --- assert the window is closed -local exists = plot:win_exists{ - win = textwindow, -} -if exists then error("closed window still exists") end diff --git a/example/demo2.lua b/example/demo2.lua deleted file mode 100644 index 66c82747..00000000 --- a/example/demo2.lua +++ /dev/null @@ -1,196 +0,0 @@ ---[[ - -Copyright 2017-present, Facebook, Inc. -All rights reserved. - -This source code is licensed under the license found in the -LICENSE file in the root directory of this source tree. - -]]-- - --- load torchnet: -require 'torch' -local tnt = require 'torchnet' - --- intialize visdom Torch client: -local visdom = require 'visdom' -local plot = visdom{server = 'http://localhost', port = 8097} - --- use GPU or not: -local cmd = torch.CmdLine() -cmd:option('-usegpu', false, 'use gpu for training') -local config = cmd:parse(arg) -print(string.format('| running on %s...', config.usegpu and 'GPU' or 'CPU')) - --- function that creates a dataset iterator: -local function getIterator(mode, batchsize) - - -- load MNIST dataset: - local mnist = require 'mnist' - local dataset = mnist[mode .. 'dataset']() - dataset.data = dataset.data:reshape(dataset.data:size(1), - dataset.data:size(2) * dataset.data:size(3)):double():div(256) - - -- return dataset iterator: - return tnt.DatasetIterator{ - dataset = tnt.BatchDataset{ - batchsize = 128, - dataset = tnt.ListDataset{ - list = torch.range(1, dataset.data:size(1)):long(), - load = function(idx) - return { - input = dataset.data[idx], - target = torch.LongTensor{dataset.label[idx] + 1}, - } -- sample contains input and target - end, - } - } - } -end - --- get data iterators: -local maxepoch = 10 -local trainiterator = getIterator('train') -local testiterator = getIterator('test') -local trainsize = trainiterator:exec('size') -local testsize = testiterator:exec('size') - --- set up logistic regressor: -local net = nn.Sequential():add(nn.Linear(784, 10)) -local criterion = nn.CrossEntropyCriterion() - --- set up training engine and meters: -local engine = tnt.SGDEngine() -local meter = tnt.AverageValueMeter() -local clerr = tnt.ClassErrorMeter{topk = {1}} - --- reset meters at start of epoch: -local epoch = 0 -engine.hooks.onStartEpoch = function(state) - epoch = epoch + 1 - print(string.format('| epoch %d of %d...', epoch, maxepoch)) - meter:reset() - clerr:reset() -end - --- compute and plot training loss / error: -local trainlosshandle, trainerrhandle -local trainlosshist = torch.DoubleTensor(trainsize * maxepoch):fill(0) -local trainerrhist = torch.DoubleTensor(trainsize * maxepoch):fill(0) -local testlosshist = torch.DoubleTensor(maxepoch):fill(0) -local testerrhist = torch.DoubleTensor(maxepoch):fill(0) -engine.hooks.onForwardCriterion = function(state) - - -- update meters: - meter:add(state.criterion.output) - clerr:add(state.network.output, state.sample.target) - - -- update loss / error history: - local idx = state.training and (state.t + 1) or epoch - local losshist = state.training and trainlosshist or testlosshist - local errhist = state.training and trainerrhist or testerrhist - losshist[idx] = meter:value() - errhist[ idx] = clerr:value{k = 1} - - -- you need at least two points to draw a line: - if state.training and state.t >= 1 then - - -- plot training loss: - trainlosshandle = plot:line{ - Y = trainlosshist:narrow(1, 1, state.t + 1), - X = torch.range(1, state.t + 1), - win = trainlosshandle, -- keep handles around so we can update plot - opts = { - markers = false, - title = 'Training loss', - xlabel = 'Batch number', - ylabel = 'Loss value', - }, - } -- create new plot if it does not yet exist, otherwise, update plot - - -- plot training error: - trainerrhandle = plot:line{ - Y = trainerrhist:narrow(1, 1, state.t + 1), - X = torch.range(1, state.t + 1), - win = trainerrhandle, -- keep handles around so we can update plot - opts = { - markers = false, - title = 'Training error', - xlabel = 'Batch number', - ylabel = 'Classification error', - }, - } -- create new plot if it does not yet exist, otherwise, update plot - end -end - --- compute test loss at end of epoch: -local testlosshandle, testerrhandle -engine.hooks.onEndEpoch = function(state) - - -- measure test error: - meter:reset() - clerr:reset() - engine:test{ - network = net, - iterator = testiterator, - criterion = criterion, - } - - -- you need at least two points to draw a line: - if epoch >= 2 then - - -- plot test loss: - testlosshandle = plot:line{ - Y = testlosshist:narrow(1, 1, state.epoch), - X = torch.range(1, state.epoch), - win = testlosshandle, -- keep handles around so we can update plot - opts = { - markers = false, - title = 'Test loss', - xlabel = 'Epoch', - ylabel = 'Loss value', - } - } -- create new plot if it does not yet exist, otherwise, update plot - - -- plot test error: - testerrhandle = plot:line{ - Y = testerrhist:narrow(1, 1, state.epoch), - X = torch.range(1, state.epoch), - win = testerrhandle, -- keep handles around so we can update plot - opts = { - markers = false, - title = 'Test error', - xlabel = 'Epoch', - ylabel = 'Classification error', - } - } -- create new plot if it does not yet exist, otherwise, update plot - end -end - --- set up GPU training: -if config.usegpu then - - -- copy model to GPU: - require 'cunn' - net = net:cuda() - criterion = criterion:cuda() - - -- copy sample to GPU buffer: - local igpu, tgpu = torch.CudaTensor(), torch.CudaTensor() - engine.hooks.onSample = function(state) - igpu:resize(state.sample.input:size() ):copy(state.sample.input) - tgpu:resize(state.sample.target:size()):copy(state.sample.target) - state.sample.input = igpu - state.sample.target = tgpu - end -- alternatively, this logic can be implemented via a TransformDataset -end - --- train the model: -engine:train{ - network = net, - iterator = trainiterator, - criterion = criterion, - lr = 0.2, - maxepoch = maxepoch, -} -print('| done.') diff --git a/py/visdom/server.py b/py/visdom/server.py index e47d1972..e82d3941 100644 --- a/py/visdom/server.py +++ b/py/visdom/server.py @@ -948,6 +948,7 @@ def update(p, args): selected_not_neg = max(0, selected) selected_exists = min(len(p['content'])-1, selected_not_neg) p['selected'] = selected_exists + return p pdata = p['content']['data'] From a09486c173661c40eb44c2cc955ff19438e8344a Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Fri, 20 Sep 2019 16:54:57 -0400 Subject: [PATCH 3/9] Removing torchfile --- py/visdom/__init__.py | 18 ------------------ setup.py | 1 - 2 files changed, 19 deletions(-) diff --git a/py/visdom/__init__.py b/py/visdom/__init__.py index 802fd6e0..2491843b 100644 --- a/py/visdom/__init__.py +++ b/py/visdom/__init__.py @@ -85,11 +85,6 @@ def do_tsne(X): except Exception: __version__ = 'no_version_file' -try: - import torchfile # type: ignore -except BaseException: - from . import torchfile - try: raise ConnectionError() except NameError: # python 2 doesn't have ConnectionError @@ -134,19 +129,6 @@ def nan2none(l): return l -def from_t7(t, b64=False): - if b64: - t = base64.b64decode(t) - - with open('/dev/shm/t7', 'wb') as ff: - ff.write(t) - ff.close() - - sf = open('/dev/shm/t7', 'rb') - - return torchfile.T7Reader(sf).read_obj() - - def loadfile(filename): assert os.path.isfile(filename), 'could not find file %s' % filename fileobj = open(filename, 'rb') diff --git a/setup.py b/setup.py index 3a69b3d3..73c79470 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,6 @@ def get_dist(pkgname): 'pyzmq', 'six', 'jsonpatch', - 'torchfile', 'websocket-client', ] pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow' From f3a1b9a220387ffa01a10d52954f2e2c286420c3 Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Thu, 19 Sep 2019 23:29:51 -0400 Subject: [PATCH 4/9] Removing lua torch support --- README.md | 37 +++++++---------------------------- py/visdom/__init__.py | 22 ++++++++++----------- py/visdom/server.py | 45 ++++++------------------------------------- th/init.lua | 4 ++++ 4 files changed, 28 insertions(+), 80 deletions(-) diff --git a/README.md b/README.md index c6d27e0a..59a03097 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ ![visdom_big](https://lh3.googleusercontent.com/-bqH9UXCw-BE/WL2UsdrrbAI/AAAAAAAAnYc/emrxwCmnrW4_CLTyyUttB0SYRJ-i4CCiQCLcB/s0/Screen+Shot+2017-03-06+at+10.51.02+AM.png"visdom_big") -A flexible tool for creating, organizing, and sharing visualizations of live, rich data. Supports Torch and Numpy. +A flexible tool for creating, organizing, and sharing visualizations of live, rich data. Supports Python. * [Overview](#overview) * [Concepts](#concepts) @@ -142,17 +142,13 @@ Using the view dropdown it is possible to select previously saved views, restori ## Setup -Requires Python 3 (and optionally Torch7) +Requires Python 3 ```bash # Install Python server and client from pip # (STABLE VERSION, NOT ALL CURRENT FEATURES ARE SUPPORTED) pip install visdom -# Install Torch client -# (STABLE VERSION, NOT ALL CURRENT FEATURES ARE SUPPORTED) -luarocks install visdom - ``` ```bash @@ -161,9 +157,6 @@ pip install -e . # If the above runs into issues, you can try the below easy_install . -# Install Torch client from source (from th directory) -luarocks make - ``` ## Usage @@ -216,29 +209,10 @@ vis.text('Hello, world!') vis.image(np.ones((3, 10, 10))) ``` -#### Torch example -```lua -require 'image' -vis = require 'visdom'() -vis:text{text = 'Hello, world!'} -vis:image{img = image.fabio()} -``` - -Some users have reported issues when connecting Lua clients to the Visdom server. -A potential work-around may be to switch off IPv6: -``` -vis = require 'visdom'() -vis.ipv6 = false -- switches off IPv6 -vis:text{text = 'Hello, world!'} -``` - - ### Demos ```bash python example/demo.py -th example/demo1.lua -th example/demo2.lua ``` @@ -261,7 +235,7 @@ The python visdom client takes a few options: - `proxies`: Dictionary mapping protocol to the URL of the proxy (e.g. {`http`: `foo.bar:3128`}) to be used on each Request. (default: `None`) - `offline`: Flag to run visdom in offline mode, where all requests are logged to file rather than to the server. Requires `log_to_filename` is set. In offline mode, all visdom commands that don't create or update plots will simply return `True`. (default: `False`) -Other options are either currently unused (endpoint, ipv6) or used for internal functionality (send allows the visdom server to replicate events for the lua client). +Other options are either currently unused (endpoint, ipv6) or used for internal functionality. ### Basics Visdom offers the following basic visualization functions: @@ -440,7 +414,7 @@ We currently assume that there are no more than 10 unique labels, in the future From the UI you can also draw a lasso around a subset of features. This will rerun the t-SNE visualization on the selected subset. #### vis.save -This function saves the `envs` that are alive on the visdom server. It takes input a list (in python) or table (in lua) of env ids to be saved. +This function saves the `envs` that are alive on the visdom server. It takes input a list of env ids to be saved. ### Plotting Further details on the wrapped plotting functions are given below. @@ -706,6 +680,9 @@ Arguments: ## License visdom is Creative Commons Attribution-NonCommercial 4.0 International Public licensed, as found in the LICENSE file. +## Note on Lua Torch Support +Support for Lua Torch was deprecated following `v0.1.8.4`. If you'd like to use torch support, you'll need to download that release. You can follow the usage instructions there, but it is no longer officially supported. + ## Contributing See guidelines for contributing [here.](./CONTRIBUTING.md) diff --git a/py/visdom/__init__.py b/py/visdom/__init__.py index cf800df7..d4330eb4 100644 --- a/py/visdom/__init__.py +++ b/py/visdom/__init__.py @@ -316,11 +316,11 @@ def _assert_opts(opts): if opts.get('columnnames'): assert isinstance(opts.get('columnnames'), list), \ - 'columnnames should be a table with column names' + 'columnnames should be a list with column names' if opts.get('rownames'): assert isinstance(opts.get('rownames'), list), \ - 'rownames should be a table with row names' + 'rownames should be a list with row names' if opts.get('jpgquality'): assert isnum(opts.get('jpgquality')), \ @@ -725,7 +725,7 @@ def _send(self, msg, endpoint='events', quiet=False, from_log=False, create=True def save(self, envs): """ This function allows the user to save envs that are alive on the - Tornado server. The envs can be specified as a table (list) of env ids. + Tornado server. The envs can be specified as a list of env ids. """ assert isinstance(envs, list), 'envs should be a list' if len(envs) > 0: @@ -1456,7 +1456,7 @@ def scatter(self, X, Y=None, win=None, env=None, opts=None, update=None, - `opts.markercolor` : marker color (`np.array`; default = `None`) - `opts.dash` : dash type (`np.array`; default = 'solid'`) - `opts.textlabels` : text label for each point (`list`: default = `None`) - - `opts.legend` : `table` containing legend names + - `opts.legend` : `list` containing legend names """ if update == 'remove': assert win is not None @@ -1655,7 +1655,7 @@ def line(self, Y, X=None, win=None, env=None, opts=None, update=None, - `opts.markersize` : marker size (`number`; default = `'10'`) - `opts.linecolor` : line colors (`np.array`; default = None) - `opts.dash` : line dash type (`np.array`; default = None) - - `opts.legend` : `table` containing legend names + - `opts.legend` : `list` containing legend names If `update` is specified, the figure will be updated without creating a new plot -- this can be used for efficient updating. @@ -1715,8 +1715,8 @@ def heatmap(self, X, win=None, env=None, opts=None): - `opts.colormap`: colormap (`string`; default = `'Viridis'`) - `opts.xmin` : clip minimum value (`number`; default = `X:min()`) - `opts.xmax` : clip maximum value (`number`; default = `X:max()`) - - `opts.columnnames`: `table` containing x-axis labels - - `opts.rownames`: `table` containing y-axis labels + - `opts.columnnames`: `list` containing x-axis labels + - `opts.rownames`: `list` containing y-axis labels - `opts.nancolor`: if not None, color for plotting nan (`string`; default = `None`) """ @@ -1781,9 +1781,9 @@ def bar(self, X, Y=None, win=None, env=None, opts=None): The following plot-specific `opts` are currently supported: - - `opts.rownames`: `table` containing x-axis labels + - `opts.rownames`: `list` containing x-axis labels - `opts.stacked` : stack multiple columns in `X` - - `opts.legend` : `table` containing legend labels + - `opts.legend` : `list` containing legend labels """ X = np.squeeze(X) assert X.ndim == 1 or X.ndim == 2, 'X should be one or two-dimensional' @@ -2081,7 +2081,7 @@ def stem(self, X, Y=None, win=None, env=None, opts=None): The following `opts` are supported: - `opts.colormap`: colormap (`string`; default = `'Viridis'`) - - `opts.legend` : `table` containing legend names + - `opts.legend` : `list` containing legend names """ X = np.squeeze(X) @@ -2124,7 +2124,7 @@ def pie(self, X, win=None, env=None, opts=None): The following `opts` are supported: - - `opts.legend`: `table` containing legend names + - `opts.legend`: `list` containing legend names """ X = np.squeeze(X) diff --git a/py/visdom/server.py b/py/visdom/server.py index 4b07bb82..988334b6 100644 --- a/py/visdom/server.py +++ b/py/visdom/server.py @@ -32,7 +32,6 @@ # for python 3.7 and below from collections import Mapping, Sequence -import visdom from zmq.eventloop import ioloop ioloop.install() # Needs to happen before any tornado imports! @@ -804,18 +803,6 @@ def register_window(self, p, eid): self.write(p['id']) -def unpack_lua(req_args): - if req_args['is_table']: - if isinstance(req_args['val'], dict): - return {k: unpack_lua(v) for (k, v) in req_args['val'].items()} - else: - return [unpack_lua(v) for v in req_args['val']] - elif req_args['is_tensor']: - return visdom.from_t7(req_args['val'], b64=True) - else: - return req_args['val'] - - class PostHandler(BaseHandler): def initialize(self, app): self.state = app.state @@ -824,9 +811,6 @@ def initialize(self, app): self.port = app.port self.env_path = app.env_path self.login_enabled = app.login_enabled - self.vis = visdom.Visdom( - port=self.port, send=False, use_incoming_socket=False - ) self.handlers = { 'update': UpdateHandler, 'save': SaveHandler, @@ -835,19 +819,6 @@ def initialize(self, app): 'delete_env': DeleteEnvHandler, } - def func(self, req): - args, kwargs = req['args'], req.get('kwargs', {}) - - args = (unpack_lua(a) for a in args) - - for k in kwargs: - v = kwargs[k] - kwargs[k] = unpack_lua(v) - - func = getattr(self.vis, req['func']) - - return func(*args, **kwargs) - @check_auth def post(self): req = tornado.escape.json_decode( @@ -855,16 +826,12 @@ def post(self): ) if req.get('func') is not None: - try: - req, endpoint = self.func(req) - if (endpoint != 'events'): - # Process the request using the proper handler - self.handlers[endpoint].wrap_func(self, req) - return - except Exception: - # get traceback and send it back - print(traceback.format_exc()) - return self.write(traceback.format_exc()) + raise Exception( + 'Support for Lua Torch was deprecated following `v0.1.8.4`. ' + "If you'd like to use torch support, you'll need to download " + "that release. You can follow the usage instructions there, " + "but it is no longer officially supported." + ) eid = extract_eid(req) p = window(req) diff --git a/th/init.lua b/th/init.lua index 60f4119f..68a2ab87 100644 --- a/th/init.lua +++ b/th/init.lua @@ -29,6 +29,10 @@ M.__init = argcheck{ visualization server that wraps plot.ly to show scalable, high-quality visualizations in the browser. + Note: The lua Torch client for visdom was deprecated after visdom + v0.1.8.4, so if you'd like to use visdom for torch, you'll have to + download that specific tag of visdom from the github. + The server can be started with the `server.py` script. The server defaults to port 8097. When the server is running on `domain.com:8097`, then visit that web address in your browser to see the visualization desktop. From fff7bd08b685e364c70af7aea9a00c30f0cc049a Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Fri, 20 Sep 2019 00:45:57 -0400 Subject: [PATCH 5/9] Removing demo, small merge fix' --- example/demo1.lua | 333 ---------------------------------------------- example/demo2.lua | 196 --------------------------- 2 files changed, 529 deletions(-) delete mode 100644 example/demo1.lua delete mode 100644 example/demo2.lua diff --git a/example/demo1.lua b/example/demo1.lua deleted file mode 100644 index c9a67b62..00000000 --- a/example/demo1.lua +++ /dev/null @@ -1,333 +0,0 @@ ---[[ - -Copyright 2017-present, Facebook, Inc. -All rights reserved. - -This source code is licensed under the license found in the -LICENSE file in the root directory of this source tree. - -]]-- - --- dependencies: -require 'torch' -require 'image' -local paths = require 'paths' - --- intialize visdom Torch client: -local visdom = require 'visdom' -local plot = visdom{server = 'http://localhost', port = 8097} -if not plot:check_connection() then - error('Could not connect, please ensure the visdom server is running') -end - --- text box demo: -local textwindow = plot:text{ - text = 'Hello, world! If I\'m still open, close failed' -} -local updatetextwindow = plot:text{ - text = 'Hello, world! If I don\'t have another line, update text failed.' -} -plot:text{text = 'Here\'s another line', win = updatetextwindow, append = true} - -plot:py_func{func='text', args={'Hello, world!'}} - --- image demo: -plot:image{ - img = image.fabio(), - opts = { - title = 'Fabio', - caption = 'Hello, I am Fabio ;)', - } -} - --- images demo: -plot:images{ - table = {torch.zeros(3, 200, 200) + 0.1, torch.zeros(3, 200, 200) + 0.2}, - opts = { - caption = 'I was a table of tensors...', - } -} - --- images demo: -plot:images{ - tensor = torch.randn(6, 3, 200, 200), - opts = { - caption = 'I was a 4D tensor...', - } -} - --- scatter plot demos: -plot:scatter{ - X = torch.randn(100, 2), - Y = torch.randn(100):gt(0):add(1):long(), - opts = { - legend = {'Apples', 'Pears'}, - xtickmin = -5, - xtickmax = 5, - xtickstep = .5, - ytickmin = -5, - ytickmax = 5, - ytickstep = .5, - markersymbol = 'cross-thin-open', - } -} -plot:scatter{ - X = torch.randn(100, 3), - Y = torch.randn(100):gt(0):add(1):long(), - opts = { - markersize = 5, - legend = {'Men', 'Women'}, - }, -} - --- 2D scatterplot with custom intensities (red channel): -local id = plot:scatter{ - X = torch.randn(255, 2), - opts = { - markersize = 10, - markercolor = torch.zeros(255):random(0, 255), - }, -} - --- check if win_exists works -local exists = plot:win_exists{ - win = id, -} -if not exists then error("created window doesn't exist") end - - -plot:line{ -- add new trace to scatter plot - X = torch.randn(255), - Y = torch.randn(255), - win = id, - name = 'new trace', - update = 'append', -} - --- 2D scatter plot with custom colors: -plot:scatter{ - X = torch.randn(255, 2), - opts = { - markersize = 10, - markercolor = torch.zeros(255, 3):random(0, 255), - }, -} - --- 2D scatter plot with custom colors per label: -plot:scatter{ - X = torch.randn(255, 2), - Y = torch.randn(255):gt(0):add(1):long(), -- two labels - opts = { - markersize = 10, - markercolor = torch.zeros(2, 3):random(0, 255), - }, -} - --- bar plot demos: -plot:bar{ - X = torch.randn(20) -} -plot:bar{ - X = torch.randn(5, 3):abs(), - opts = { - stacked = true, - legend = {'Facebook', 'Google', 'Twitter'}, - rownames = {'2012', '2013', '2014', '2015', '2016'}, - }, -} -plot:bar{ - X = torch.randn(20, 3), - opts = { - stacked = false, - legend = {'The Netherlands', 'France', 'United States'}, - }, -} - --- histogram demo: -plot:histogram{ - X = torch.randn(10000), - opts = {numbins = 20}, -} - --- heatmap demo: -local X = torch.cmul(torch.range(1, 10):reshape(1, 10):expand(5, 10), - torch.range(1, 5):reshape(5, 1):expand(5, 10)) -plot:heatmap{ - X = X, - opts = { - columnnames = {'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'}, - rownames = {'y1', 'y2', 'y3', 'y4', 'y5'}, - colormap = 'Electric', - }, -} - --- contour plot demo: -local x = torch.range(1, 100):reshape(1, 100):expand(100, 100) -local y = torch.range(1, 100):reshape(100, 1):expand(100, 100) -local X = torch.add(x, -50):pow(2):add( - torch.add(y, -50):pow(2) -):div(-math.pow(20, 2)):exp() -plot:contour{X = X, opts = {colormap = 'Viridis'}} - --- surface plot demo: -plot:surf{X = X, opts = {colormap = 'Hot'}} - --- line plot demos: -local Y = torch.range(-4, 4, 0.05) -plot:line{ - Y = torch.cat(torch.cmul(Y, Y), torch.sqrt(Y + 5), 2), - X = torch.cat(Y, Y, 2), - opts = {markers = false} -} - -local id = plot:line{ - Y = torch.cat(torch.range(0, 10), torch.range(0, 10) + 5, 2), - X = torch.cat(torch.range(0, 10), torch.range(0, 10), 2), - opts = {markers = false} -} - -plot:py_func{ - func='line', - args = {torch.randn(10)}, - kwargs = {opts = {title = 'This is lua through python'}} -} - --- update trace demos: -plot:line{ - X = torch.cat(torch.range(11, 20), torch.range(11, 20), 2), - Y = torch.cat(torch.range(11, 20), torch.range(5, 14) * 2 + 5, 2), - win = id, - update = 'append', -} - -plot:line{ - X = torch.range(1, 10), - Y = torch.range(1, 10), - win = id, - name = '3', - update = 'append', -} - -plot:line{ - X = torch.range(1, 10), - Y = torch.range(11, 20), - win = id, - name = '4', - update = 'append', -} - --- stacked line plot demo: -local Y = torch.range(0, 4, 0.02) -plot:line{ - Y = torch.cat(torch.sqrt(Y), torch.sqrt(Y):add(2), 2), - X = torch.cat(Y, Y, 2), - opts = { - fillarea = true, - legend = false, - width = 400, - height = 400, - xlabel = 'Time', - ylabel = 'Volume', - ytype = 'log', - title = 'Stacked area plot', - marginleft = 30, - marginright = 30, - marginbottom = 80, - margintop = 30, - }, -} - --- boxplot demo: -local X = torch.randn(100, 2) -X:narrow(2, 2, 1):add(2) -plot:boxplot{ - X = X, - opts = { - legend = {'Men', 'Women'}, - }, -} - --- stem plot demo: -local Y = torch.range(0, 2 * math.pi, (2 * math.pi) / 70) -local X = torch.cat(torch.sin(Y), torch.cos(Y), 2) -plot:stem{ - X = X, - Y = Y, - opts = { - legend = {'Sine', 'Cosine'}, - }, -} - --- quiver demo: -local X = torch.range(0, 2, 0.2) -local Y = torch.range(0, 2, 0.2) -X = X:resize(1, X:nElement()):expand(X:nElement(), X:nElement()) -Y = Y:resize(Y:nElement(), 1):expand(Y:nElement(), Y:nElement()) -local U = torch.cos(X):cmul(Y) -local V = torch.sin(X):cmul(Y) -plot:quiver{ - X = U, - Y = V, - opts = {normalize = 0.9}, -} - --- pie chart demo: -local X = torch.DoubleTensor{19, 26, 55} -local legend = {'Residential', 'Non-Residential', 'Utility'} -plot:pie{ - X = X, - opts = {legend = legend}, -} - --- svg rendering demo: -local svgstr = [[ - - - Sorry, your browser does not support inline SVG. - -]] -plot:svg{ - svgstr = svgstr, - opts = { - title = 'Example of SVG Rendering', - }, -} - --- mesh plot demo: -local X = torch.DoubleTensor{ - {0, 0, 1, 1, 0, 0, 1, 1}, - {0, 1, 1, 0, 0, 1, 1, 0}, - {0, 0, 0, 0, 1, 1, 1, 1}, -}:t() -local Y = torch.DoubleTensor{ - {7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2}, - {3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3}, - {0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6}, -}:t() -plot:mesh{X = X, Y = Y, opts = {opacity = 0.5}} - --- video demo: -local video = torch.ByteTensor(256, 3, 128, 128) -for n = 1,video:size(1) do - video[n]:fill(n - 1) -end -local ok = pcall(plot.video, plot.video, {tensor = video}) -if not ok then print('Skipped video example') end - --- video demo: -local videofile = '/home/' .. os.getenv('USER') .. '/trailer.ogv' - -- NOTE: Download video from http://media.w3.org/2010/05/sintel/trailer.ogv -if paths.filep(videofile) then - local ok = pcall(plot.video, plot.video, {videofile = videofile}) - if not ok then print('Skipped video example') end -end - --- close text window: -plot:close{win = textwindow} - --- assert the window is closed -local exists = plot:win_exists{ - win = textwindow, -} -if exists then error("closed window still exists") end diff --git a/example/demo2.lua b/example/demo2.lua deleted file mode 100644 index 66c82747..00000000 --- a/example/demo2.lua +++ /dev/null @@ -1,196 +0,0 @@ ---[[ - -Copyright 2017-present, Facebook, Inc. -All rights reserved. - -This source code is licensed under the license found in the -LICENSE file in the root directory of this source tree. - -]]-- - --- load torchnet: -require 'torch' -local tnt = require 'torchnet' - --- intialize visdom Torch client: -local visdom = require 'visdom' -local plot = visdom{server = 'http://localhost', port = 8097} - --- use GPU or not: -local cmd = torch.CmdLine() -cmd:option('-usegpu', false, 'use gpu for training') -local config = cmd:parse(arg) -print(string.format('| running on %s...', config.usegpu and 'GPU' or 'CPU')) - --- function that creates a dataset iterator: -local function getIterator(mode, batchsize) - - -- load MNIST dataset: - local mnist = require 'mnist' - local dataset = mnist[mode .. 'dataset']() - dataset.data = dataset.data:reshape(dataset.data:size(1), - dataset.data:size(2) * dataset.data:size(3)):double():div(256) - - -- return dataset iterator: - return tnt.DatasetIterator{ - dataset = tnt.BatchDataset{ - batchsize = 128, - dataset = tnt.ListDataset{ - list = torch.range(1, dataset.data:size(1)):long(), - load = function(idx) - return { - input = dataset.data[idx], - target = torch.LongTensor{dataset.label[idx] + 1}, - } -- sample contains input and target - end, - } - } - } -end - --- get data iterators: -local maxepoch = 10 -local trainiterator = getIterator('train') -local testiterator = getIterator('test') -local trainsize = trainiterator:exec('size') -local testsize = testiterator:exec('size') - --- set up logistic regressor: -local net = nn.Sequential():add(nn.Linear(784, 10)) -local criterion = nn.CrossEntropyCriterion() - --- set up training engine and meters: -local engine = tnt.SGDEngine() -local meter = tnt.AverageValueMeter() -local clerr = tnt.ClassErrorMeter{topk = {1}} - --- reset meters at start of epoch: -local epoch = 0 -engine.hooks.onStartEpoch = function(state) - epoch = epoch + 1 - print(string.format('| epoch %d of %d...', epoch, maxepoch)) - meter:reset() - clerr:reset() -end - --- compute and plot training loss / error: -local trainlosshandle, trainerrhandle -local trainlosshist = torch.DoubleTensor(trainsize * maxepoch):fill(0) -local trainerrhist = torch.DoubleTensor(trainsize * maxepoch):fill(0) -local testlosshist = torch.DoubleTensor(maxepoch):fill(0) -local testerrhist = torch.DoubleTensor(maxepoch):fill(0) -engine.hooks.onForwardCriterion = function(state) - - -- update meters: - meter:add(state.criterion.output) - clerr:add(state.network.output, state.sample.target) - - -- update loss / error history: - local idx = state.training and (state.t + 1) or epoch - local losshist = state.training and trainlosshist or testlosshist - local errhist = state.training and trainerrhist or testerrhist - losshist[idx] = meter:value() - errhist[ idx] = clerr:value{k = 1} - - -- you need at least two points to draw a line: - if state.training and state.t >= 1 then - - -- plot training loss: - trainlosshandle = plot:line{ - Y = trainlosshist:narrow(1, 1, state.t + 1), - X = torch.range(1, state.t + 1), - win = trainlosshandle, -- keep handles around so we can update plot - opts = { - markers = false, - title = 'Training loss', - xlabel = 'Batch number', - ylabel = 'Loss value', - }, - } -- create new plot if it does not yet exist, otherwise, update plot - - -- plot training error: - trainerrhandle = plot:line{ - Y = trainerrhist:narrow(1, 1, state.t + 1), - X = torch.range(1, state.t + 1), - win = trainerrhandle, -- keep handles around so we can update plot - opts = { - markers = false, - title = 'Training error', - xlabel = 'Batch number', - ylabel = 'Classification error', - }, - } -- create new plot if it does not yet exist, otherwise, update plot - end -end - --- compute test loss at end of epoch: -local testlosshandle, testerrhandle -engine.hooks.onEndEpoch = function(state) - - -- measure test error: - meter:reset() - clerr:reset() - engine:test{ - network = net, - iterator = testiterator, - criterion = criterion, - } - - -- you need at least two points to draw a line: - if epoch >= 2 then - - -- plot test loss: - testlosshandle = plot:line{ - Y = testlosshist:narrow(1, 1, state.epoch), - X = torch.range(1, state.epoch), - win = testlosshandle, -- keep handles around so we can update plot - opts = { - markers = false, - title = 'Test loss', - xlabel = 'Epoch', - ylabel = 'Loss value', - } - } -- create new plot if it does not yet exist, otherwise, update plot - - -- plot test error: - testerrhandle = plot:line{ - Y = testerrhist:narrow(1, 1, state.epoch), - X = torch.range(1, state.epoch), - win = testerrhandle, -- keep handles around so we can update plot - opts = { - markers = false, - title = 'Test error', - xlabel = 'Epoch', - ylabel = 'Classification error', - } - } -- create new plot if it does not yet exist, otherwise, update plot - end -end - --- set up GPU training: -if config.usegpu then - - -- copy model to GPU: - require 'cunn' - net = net:cuda() - criterion = criterion:cuda() - - -- copy sample to GPU buffer: - local igpu, tgpu = torch.CudaTensor(), torch.CudaTensor() - engine.hooks.onSample = function(state) - igpu:resize(state.sample.input:size() ):copy(state.sample.input) - tgpu:resize(state.sample.target:size()):copy(state.sample.target) - state.sample.input = igpu - state.sample.target = tgpu - end -- alternatively, this logic can be implemented via a TransformDataset -end - --- train the model: -engine:train{ - network = net, - iterator = trainiterator, - criterion = criterion, - lr = 0.2, - maxepoch = maxepoch, -} -print('| done.') From 4ab9675a228f1eb9e2477df52e1aad682f1a545d Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Fri, 20 Sep 2019 16:54:57 -0400 Subject: [PATCH 6/9] Removing torchfile --- py/visdom/__init__.py | 18 ------------------ setup.py | 1 - 2 files changed, 19 deletions(-) diff --git a/py/visdom/__init__.py b/py/visdom/__init__.py index d4330eb4..9f312317 100644 --- a/py/visdom/__init__.py +++ b/py/visdom/__init__.py @@ -82,11 +82,6 @@ def do_tsne(X): except Exception: __version__ = 'no_version_file' -try: - import torchfile # type: ignore -except BaseException: - from . import torchfile - logging.getLogger('requests').setLevel(logging.CRITICAL) logging.getLogger('urllib3').setLevel(logging.CRITICAL) logger = logging.getLogger(__name__) @@ -123,19 +118,6 @@ def nan2none(l): return l -def from_t7(t, b64=False): - if b64: - t = base64.b64decode(t) - - with open('/dev/shm/t7', 'wb') as ff: - ff.write(t) - ff.close() - - sf = open('/dev/shm/t7', 'rb') - - return torchfile.T7Reader(sf).read_obj() - - def loadfile(filename): assert os.path.isfile(filename), 'could not find file %s' % filename fileobj = open(filename, 'rb') diff --git a/setup.py b/setup.py index 3a69b3d3..73c79470 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,6 @@ def get_dist(pkgname): 'pyzmq', 'six', 'jsonpatch', - 'torchfile', 'websocket-client', ] pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow' From ce8082091bfae63fd0798e35c327314f40821a2a Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Sat, 21 Sep 2019 14:28:23 -0400 Subject: [PATCH 7/9] rebase --- py/visdom/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/visdom/__init__.py b/py/visdom/__init__.py index 9f312317..d69ae40a 100644 --- a/py/visdom/__init__.py +++ b/py/visdom/__init__.py @@ -339,8 +339,8 @@ def _to_numpy(a): if isinstance(a, torch.autograd.Variable): # For PyTorch < 0.4 comptability. warnings.warn( - "Support for versions of PyTorch less than 0.4 is deprecated and " - "will eventually be removed.", DeprecationWarning) + "Support for versions of PyTorch less than 0.4 is deprecated " + "and will eventually be removed.", DeprecationWarning) a = a.data for kind in torch_types: if isinstance(a, kind): From 34f3ea9533347899c5e7f9dafb39ff464f8fee22 Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Sun, 22 Sep 2019 23:35:49 -0400 Subject: [PATCH 8/9] Refactoring server.py into more intentional files --- py/visdom/__init__.py | 2 + py/visdom/server/__main__.py | 15 + py/visdom/server/app.py | 159 ++++ py/visdom/server/build.py | 129 +++ py/visdom/server/defaults.py | 14 + .../handlers/all_handlers.py} | 844 +----------------- py/visdom/server/handlers/base_handlers.py | 103 +++ py/visdom/server/run_server.py | 171 ++++ py/visdom/utils/server_utils.py | 414 +++++++++ py/visdom/utils/shared_utils.py | 62 ++ setup.py | 3 +- 11 files changed, 1076 insertions(+), 840 deletions(-) create mode 100644 py/visdom/server/__main__.py create mode 100644 py/visdom/server/app.py create mode 100644 py/visdom/server/build.py create mode 100644 py/visdom/server/defaults.py rename py/visdom/{server.py => server/handlers/all_handlers.py} (53%) create mode 100644 py/visdom/server/handlers/base_handlers.py create mode 100644 py/visdom/server/run_server.py create mode 100644 py/visdom/utils/server_utils.py create mode 100644 py/visdom/utils/shared_utils.py diff --git a/py/visdom/__init__.py b/py/visdom/__init__.py index d69ae40a..5e7c6dee 100644 --- a/py/visdom/__init__.py +++ b/py/visdom/__init__.py @@ -6,6 +6,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from visdom.utils.shared_utils import get_new_window_id +from visdom import server import os.path import requests import traceback diff --git a/py/visdom/server/__main__.py b/py/visdom/server/__main__.py new file mode 100644 index 00000000..aa1195d5 --- /dev/null +++ b/py/visdom/server/__main__.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 + +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +assert sys.version_info[0] >= 3, 'To use visdom with python 2, downgrade to v0.1.8.9' + +if __name__ == "__main__": + from visdom.server.run_server import download_scripts_and_run + download_scripts_and_run() diff --git a/py/visdom/server/app.py b/py/visdom/server/app.py new file mode 100644 index 00000000..955ab74f --- /dev/null +++ b/py/visdom/server/app.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Main application class that pulls handlers together and maintains +all of the required state about the currently running server. +""" + +from visdom.utils.shared_utils import warn_once, ensure_dir_exists, get_visdom_path_to + +from visdom.utils.server_utils import ( + serialize_env, +) + +# TODO replace this next +from visdom.server.handlers.all_handlers import * + +import copy +import hashlib +import logging +import os +import time + +import tornado.web # noqa E402: gotta install ioloop first +import tornado.escape # noqa E402: gotta install ioloop first + +LAYOUT_FILE = 'layouts.json' + +tornado_settings = { + "autoescape": None, + "debug": "/dbg/" in __file__, + "static_path": get_visdom_path_to('static'), + "template_path": get_visdom_path_to('static'), + "compiled_template_cache": False +} + +class Application(tornado.web.Application): + def __init__(self, port=DEFAULT_PORT, base_url='', + env_path=DEFAULT_ENV_PATH, readonly=False, + user_credential=None, use_frontend_client_polling=False): + self.env_path = env_path + self.state = self.load_state() + self.layouts = self.load_layouts() + self.subs = {} + self.sources = {} + self.port = port + self.base_url = base_url + self.readonly = readonly + self.user_credential = user_credential + self.login_enabled = False + self.last_access = time.time() + self.wrap_socket = use_frontend_client_polling + + if user_credential: + self.login_enabled = True + with open(DEFAULT_ENV_PATH + "COOKIE_SECRET", "r") as fn: + tornado_settings["cookie_secret"] = fn.read() + + tornado_settings['static_url_prefix'] = self.base_url + "/static/" + tornado_settings['debug'] = True + handlers = [ + (r"%s/events" % self.base_url, PostHandler, {'app': self}), + (r"%s/update" % self.base_url, UpdateHandler, {'app': self}), + (r"%s/close" % self.base_url, CloseHandler, {'app': self}), + (r"%s/socket" % self.base_url, SocketHandler, {'app': self}), + (r"%s/socket_wrap" % self.base_url, SocketWrap, {'app': self}), + (r"%s/vis_socket" % self.base_url, + VisSocketHandler, {'app': self}), + (r"%s/vis_socket_wrap" % self.base_url, + VisSocketWrap, {'app': self}), + (r"%s/env/(.*)" % self.base_url, EnvHandler, {'app': self}), + (r"%s/compare/(.*)" % self.base_url, + CompareHandler, {'app': self}), + (r"%s/save" % self.base_url, SaveHandler, {'app': self}), + (r"%s/error/(.*)" % self.base_url, ErrorHandler, {'app': self}), + (r"%s/win_exists" % self.base_url, ExistsHandler, {'app': self}), + (r"%s/win_data" % self.base_url, DataHandler, {'app': self}), + (r"%s/delete_env" % self.base_url, + DeleteEnvHandler, {'app': self}), + (r"%s/win_hash" % self.base_url, HashHandler, {'app': self}), + (r"%s/env_state" % self.base_url, EnvStateHandler, {'app': self}), + (r"%s/fork_env" % self.base_url, ForkEnvHandler, {'app': self}), + (r"%s(.*)" % self.base_url, IndexHandler, {'app': self}), + ] + super(Application, self).__init__(handlers, **tornado_settings) + + def get_last_access(self): + if len(self.subs) > 0 or len(self.sources) > 0: + # update the last access time to now, as someone + # is currently connected to the server + self.last_access = time.time() + return self.last_access + + def save_layouts(self): + if self.env_path is None: + warn_once( + 'Saving and loading to disk has no effect when running with ' + 'env_path=None.', + RuntimeWarning + ) + return + layout_filepath = os.path.join(self.env_path, 'view', LAYOUT_FILE) + with open(layout_filepath, 'w') as fn: + fn.write(self.layouts) + + def load_layouts(self): + if self.env_path is None: + warn_once( + 'Saving and loading to disk has no effect when running with ' + 'env_path=None.', + RuntimeWarning + ) + return "" + layout_filepath = os.path.join(self.env_path, 'view', LAYOUT_FILE) + ensure_dir_exists(layout_filepath) + if os.path.isfile(layout_filepath): + with open(layout_filepath, 'r') as fn: + return fn.read() + else: + return "" + + def load_state(self): + state = {} + env_path = self.env_path + if env_path is None: + warn_once( + 'Saving and loading to disk has no effect when running with ' + 'env_path=None.', + RuntimeWarning + ) + return {'main': {'jsons': {}, 'reload': {}}} + ensure_dir_exists(env_path) + env_jsons = [i for i in os.listdir(env_path) if '.json' in i] + + for env_json in env_jsons: + env_path_file = os.path.join(env_path, env_json) + try: + with open(env_path_file, 'r') as fn: + env_data = tornado.escape.json_decode(fn.read()) + except Exception as e: + logging.warn( + "Failed loading environment json: {} - {}".format( + env_path_file, repr(e))) + continue + + eid = env_json.replace('.json', '') + state[eid] = {'jsons': env_data['jsons'], + 'reload': env_data['reload']} + + if 'main' not in state and 'main.json' not in env_jsons: + state['main'] = {'jsons': {}, 'reload': {}} + serialize_env(state, ['main'], env_path=self.env_path) + + return state diff --git a/py/visdom/server/build.py b/py/visdom/server/build.py new file mode 100644 index 00000000..9ca41890 --- /dev/null +++ b/py/visdom/server/build.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 + +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import visdom +from visdom.utils.shared_utils import ensure_dir_exists, get_visdom_path +import os +from urllib import request +from urllib.error import HTTPError, URLError + + +def download_scripts(proxies=None, install_dir=None): + """ + Function to download all of the javascript, css, and font dependencies, + and put them in the correct locations to run the server + """ + print("Checking for scripts.") + + # location in which to download stuff: + if install_dir is None: + install_dir = get_visdom_path() + + # all files that need to be downloaded: + b = 'https://unpkg.com/' + bb = '%sbootstrap@3.3.7/dist/' % b + ext_files = { + # - js + '%sjquery@3.1.1/dist/jquery.min.js' % b: 'jquery.min.js', + '%sbootstrap@3.3.7/dist/js/bootstrap.min.js' % b: 'bootstrap.min.js', + '%sreact@16.2.0/umd/react.production.min.js' % b: 'react-react.min.js', + '%sreact-dom@16.2.0/umd/react-dom.production.min.js' % b: + 'react-dom.min.js', + '%sreact-modal@3.1.10/dist/react-modal.min.js' % b: + 'react-modal.min.js', + 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_SVG': # noqa + 'mathjax-MathJax.js', + # here is another url in case the cdn breaks down again. + # https://raw.githubusercontent.com/plotly/plotly.js/master/dist/plotly.min.js + 'https://cdn.plot.ly/plotly-latest.min.js': 'plotly-plotly.min.js', + # Stanford Javascript Crypto Library for Password Hashing + '%ssjcl@1.0.7/sjcl.js' % b: 'sjcl.js', + + # - css + '%sreact-resizable@1.4.6/css/styles.css' % b: + 'react-resizable-styles.css', + '%sreact-grid-layout@0.16.3/css/styles.css' % b: + 'react-grid-layout-styles.css', + '%scss/bootstrap.min.css' % bb: 'bootstrap.min.css', + + # - fonts + '%sclassnames@2.2.5' % b: 'classnames', + '%slayout-bin-packer@1.4.0/dist/layout-bin-packer.js' % b: + 'layout_bin_packer.js', + '%sfonts/glyphicons-halflings-regular.eot' % bb: + 'glyphicons-halflings-regular.eot', + '%sfonts/glyphicons-halflings-regular.woff2' % bb: + 'glyphicons-halflings-regular.woff2', + '%sfonts/glyphicons-halflings-regular.woff' % bb: + 'glyphicons-halflings-regular.woff', + '%sfonts/glyphicons-halflings-regular.ttf' % bb: + 'glyphicons-halflings-regular.ttf', + '%sfonts/glyphicons-halflings-regular.svg#glyphicons_halflingsregular' % bb: # noqa + 'glyphicons-halflings-regular.svg#glyphicons_halflingsregular', + } + + # make sure all relevant folders exist: + dir_list = [ + '%s' % install_dir, + '%s/static' % install_dir, + '%s/static/js' % install_dir, + '%s/static/css' % install_dir, + '%s/static/fonts' % install_dir, + ] + for directory in dir_list: + if not os.path.exists(directory): + os.makedirs(directory) + + # set up proxy handler: + handler = request.ProxyHandler(proxies) if proxies is not None \ + else request.BaseHandler() + opener = request.build_opener(handler) + request.install_opener(opener) + + built_path = os.path.join(install_dir, 'static/version.built') + is_built = visdom.__version__ == 'no_version_file' + if os.path.exists(built_path): + with open(built_path, 'r') as build_file: + build_version = build_file.read().strip() + if build_version == visdom.__version__: + is_built = True + else: + os.remove(built_path) + if not is_built: + print('Downloading scripts, this may take a little while') + + # download files one-by-one: + for (key, val) in ext_files.items(): + + # set subdirectory: + if val.endswith('.js'): + sub_dir = 'js' + elif val.endswith('.css'): + sub_dir = 'css' + else: + sub_dir = 'fonts' + + # download file: + filename = '%s/static/%s/%s' % (install_dir, sub_dir, val) + if not os.path.exists(filename) or not is_built: + req = request.Request(key, + headers={'User-Agent': 'Chrome/30.0.0.0'}) + try: + data = opener.open(req).read() + with open(filename, 'wb') as fwrite: + fwrite.write(data) + except HTTPError as exc: + logging.error('Error {} while downloading {}'.format( + exc.code, key)) + except URLError as exc: + logging.error('Error {} while downloading {}'.format( + exc.reason, key)) + + if not is_built: + with open(built_path, 'w+') as build_file: + build_file.write(visdom.__version__) diff --git a/py/visdom/server/defaults.py b/py/visdom/server/defaults.py new file mode 100644 index 00000000..99957ef5 --- /dev/null +++ b/py/visdom/server/defaults.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 + +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from os.path import expanduser + +DEFAULT_ENV_PATH = '%s/.visdom/' % expanduser("~") +DEFAULT_PORT = 8097 +DEFAULT_HOSTNAME = "localhost" +DEFAULT_BASE_URL = "/" diff --git a/py/visdom/server.py b/py/visdom/server/handlers/all_handlers.py similarity index 53% rename from py/visdom/server.py rename to py/visdom/server/handlers/all_handlers.py index 988334b6..798a5717 100644 --- a/py/visdom/server.py +++ b/py/visdom/server/handlers/all_handlers.py @@ -8,22 +8,19 @@ """Server""" -import argparse +# TODO fix these imports +from visdom.utils.shared_utils import * +from visdom.utils.server_utils import * +from visdom.server.handlers.base_handlers import * import copy import getpass import hashlib -import inspect import json import jsonpatch import logging import math import os -import sys import time -import traceback -import uuid -import warnings -from os.path import expanduser from collections import OrderedDict try: # for after python 3.8 @@ -32,273 +29,19 @@ # for python 3.7 and below from collections import Mapping, Sequence -from zmq.eventloop import ioloop -ioloop.install() # Needs to happen before any tornado imports! - import tornado.ioloop # noqa E402: gotta install ioloop first import tornado.web # noqa E402: gotta install ioloop first import tornado.websocket # noqa E402: gotta install ioloop first import tornado.escape # noqa E402: gotta install ioloop first LAYOUT_FILE = 'layouts.json' -DEFAULT_ENV_PATH = '%s/.visdom/' % expanduser("~") -DEFAULT_PORT = 8097 -DEFAULT_HOSTNAME = "localhost" -DEFAULT_BASE_URL = "/" here = os.path.abspath(os.path.dirname(__file__)) COMPACT_SEPARATORS = (',', ':') -_seen_warnings = set() - MAX_SOCKET_WAIT = 15 -assert sys.version_info[0] >= 3, 'To use visdom with python 2, downgrade to v0.1.8.9' - - -def warn_once(msg, warningtype=None): - """ - Raise a warning, but only once. - :param str msg: Message to display - :param Warning warningtype: Type of warning, e.g. DeprecationWarning - """ - global _seen_warnings - if msg not in _seen_warnings: - _seen_warnings.add(msg) - warnings.warn(msg, warningtype, stacklevel=2) - - -def check_auth(f): - def _check_auth(self, *args, **kwargs): - self.last_access = time.time() - if self.login_enabled and not self.current_user: - self.set_status(400) - return - f(self, *args, **kwargs) - return _check_auth - - -def get_rand_id(): - return str(uuid.uuid4()) - - -def ensure_dir_exists(path): - """Make sure the parent dir exists for path so we can write a file.""" - try: - os.makedirs(os.path.dirname(path)) - except OSError as e1: - assert e1.errno == 17 # errno.EEXIST - pass - - -def get_path(filename): - """Get the path to an asset.""" - cwd = os.path.dirname( - os.path.abspath(inspect.getfile(inspect.currentframe()))) - return os.path.join(cwd, filename) - - -def escape_eid(eid): - """Replace slashes with underscores, to avoid recognizing them - as directories. - """ - - return eid.replace('/', '_') - - -def extract_eid(args): - """Extract eid from args. If eid does not exist in args, - it returns 'main'.""" - - eid = 'main' if args.get('eid') is None else args.get('eid') - return escape_eid(eid) - - -def set_cookie(value=None): - """Create cookie secret key for authentication""" - if value is not None: - cookie_secret = value - else: - cookie_secret = input("Please input your cookie secret key here: ") - with open(DEFAULT_ENV_PATH + "COOKIE_SECRET", "w") as cookie_file: - cookie_file.write(cookie_secret) - - -def hash_password(password): - """Hashing Password with SHA-256""" - return hashlib.sha256(password.encode("utf-8")).hexdigest() - - -tornado_settings = { - "autoescape": None, - "debug": "/dbg/" in __file__, - "static_path": get_path('static'), - "template_path": get_path('static'), - "compiled_template_cache": False -} - - -def serialize_env(state, eids, env_path=DEFAULT_ENV_PATH): - env_ids = [i for i in eids if i in state] - if env_path is not None: - for env_id in env_ids: - env_path_file = os.path.join(env_path, "{0}.json".format(env_id)) - with open(env_path_file, 'w') as fn: - fn.write(json.dumps(state[env_id])) - return env_ids - - -def serialize_all(state, env_path=DEFAULT_ENV_PATH): - serialize_env(state, list(state.keys()), env_path=env_path) - - -class Application(tornado.web.Application): - def __init__(self, port=DEFAULT_PORT, base_url='', - env_path=DEFAULT_ENV_PATH, readonly=False, - user_credential=None, use_frontend_client_polling=False): - self.env_path = env_path - self.state = self.load_state() - self.layouts = self.load_layouts() - self.subs = {} - self.sources = {} - self.port = port - self.base_url = base_url - self.readonly = readonly - self.user_credential = user_credential - self.login_enabled = False - self.last_access = time.time() - self.wrap_socket = use_frontend_client_polling - - if user_credential: - self.login_enabled = True - with open(DEFAULT_ENV_PATH + "COOKIE_SECRET", "r") as fn: - tornado_settings["cookie_secret"] = fn.read() - - tornado_settings['static_url_prefix'] = self.base_url + "/static/" - tornado_settings['debug'] = True - handlers = [ - (r"%s/events" % self.base_url, PostHandler, {'app': self}), - (r"%s/update" % self.base_url, UpdateHandler, {'app': self}), - (r"%s/close" % self.base_url, CloseHandler, {'app': self}), - (r"%s/socket" % self.base_url, SocketHandler, {'app': self}), - (r"%s/socket_wrap" % self.base_url, SocketWrap, {'app': self}), - (r"%s/vis_socket" % self.base_url, - VisSocketHandler, {'app': self}), - (r"%s/vis_socket_wrap" % self.base_url, - VisSocketWrap, {'app': self}), - (r"%s/env/(.*)" % self.base_url, EnvHandler, {'app': self}), - (r"%s/compare/(.*)" % self.base_url, - CompareHandler, {'app': self}), - (r"%s/save" % self.base_url, SaveHandler, {'app': self}), - (r"%s/error/(.*)" % self.base_url, ErrorHandler, {'app': self}), - (r"%s/win_exists" % self.base_url, ExistsHandler, {'app': self}), - (r"%s/win_data" % self.base_url, DataHandler, {'app': self}), - (r"%s/delete_env" % self.base_url, - DeleteEnvHandler, {'app': self}), - (r"%s/win_hash" % self.base_url, HashHandler, {'app': self}), - (r"%s/env_state" % self.base_url, EnvStateHandler, {'app': self}), - (r"%s/fork_env" % self.base_url, ForkEnvHandler, {'app': self}), - (r"%s(.*)" % self.base_url, IndexHandler, {'app': self}), - ] - super(Application, self).__init__(handlers, **tornado_settings) - - def get_last_access(self): - if len(self.subs) > 0 or len(self.sources) > 0: - # update the last access time to now, as someone - # is currently connected to the server - self.last_access = time.time() - return self.last_access - - def save_layouts(self): - if self.env_path is None: - warn_once( - 'Saving and loading to disk has no effect when running with ' - 'env_path=None.', - RuntimeWarning - ) - return - layout_filepath = os.path.join(self.env_path, 'view', LAYOUT_FILE) - with open(layout_filepath, 'w') as fn: - fn.write(self.layouts) - - def load_layouts(self): - if self.env_path is None: - warn_once( - 'Saving and loading to disk has no effect when running with ' - 'env_path=None.', - RuntimeWarning - ) - return "" - layout_filepath = os.path.join(self.env_path, 'view', LAYOUT_FILE) - ensure_dir_exists(layout_filepath) - if os.path.isfile(layout_filepath): - with open(layout_filepath, 'r') as fn: - return fn.read() - else: - return "" - - def load_state(self): - state = {} - env_path = self.env_path - if env_path is None: - warn_once( - 'Saving and loading to disk has no effect when running with ' - 'env_path=None.', - RuntimeWarning - ) - return {'main': {'jsons': {}, 'reload': {}}} - ensure_dir_exists(env_path) - env_jsons = [i for i in os.listdir(env_path) if '.json' in i] - - for env_json in env_jsons: - env_path_file = os.path.join(env_path, env_json) - try: - with open(env_path_file, 'r') as fn: - env_data = tornado.escape.json_decode(fn.read()) - except Exception as e: - logging.warn( - "Failed loading environment json: {} - {}".format( - env_path_file, repr(e))) - continue - - eid = env_json.replace('.json', '') - state[eid] = {'jsons': env_data['jsons'], - 'reload': env_data['reload']} - - if 'main' not in state and 'main.json' not in env_jsons: - state['main'] = {'jsons': {}, 'reload': {}} - serialize_env(state, ['main'], env_path=self.env_path) - - return state - - -def broadcast_envs(handler, target_subs=None): - if target_subs is None: - target_subs = handler.subs.values() - for sub in target_subs: - sub.write_message(json.dumps( - {'command': 'env_update', 'data': list(handler.state.keys())} - )) - - -def send_to_sources(handler, msg): - target_sources = handler.sources.values() - for source in target_sources: - source.write_message(json.dumps(msg)) - - -class BaseWebSocketHandler(tornado.websocket.WebSocketHandler): - def get_current_user(self): - """ - This method determines the self.current_user - based the value of cookies that set in POST method - at IndexHandler by self.set_secure_cookie - """ - try: - return self.get_secure_cookie("user_password") - except Exception: # Not using secure cookies - return None - +# TODO Split this file up it's terrible class VisSocketHandler(BaseWebSocketHandler): def initialize(self, app): @@ -669,140 +412,6 @@ def get_messages(self): return to_send -class BaseHandler(tornado.web.RequestHandler): - def __init__(self, *request, **kwargs): - self.include_host = False - super(BaseHandler, self).__init__(*request, **kwargs) - - def get_current_user(self): - """ - This method determines the self.current_user - based the value of cookies that set in POST method - at IndexHandler by self.set_secure_cookie - """ - try: - return self.get_secure_cookie("user_password") - except Exception: # Not using secure cookies - return None - - def write_error(self, status_code, **kwargs): - logging.error("ERROR: %s: %s" % (status_code, kwargs)) - if "exc_info" in kwargs: - logging.info('Traceback: {}'.format( - traceback.format_exception(*kwargs["exc_info"]))) - if self.settings.get("debug") and "exc_info" in kwargs: - logging.error("rendering error page") - exc_info = kwargs["exc_info"] - # exc_info is a tuple consisting of: - # 1. The class of the Exception - # 2. The actual Exception that was thrown - # 3. The traceback opbject - try: - params = { - 'error': exc_info[1], - 'trace_info': traceback.format_exception(*exc_info), - 'request': self.request.__dict__ - } - - self.render("error.html", **params) - logging.error("rendering complete") - except Exception as e: - logging.error(e) - - -def update_window(p, args): - """Adds new args to a window if they exist""" - content = p['content'] - layout_update = args.get('layout', {}) - for layout_name, layout_val in layout_update.items(): - if layout_val is not None: - content['layout'][layout_name] = layout_val - opts = args.get('opts', {}) - for opt_name, opt_val in opts.items(): - if opt_val is not None: - p[opt_name] = opt_val - - if 'legend' in opts: - pdata = p['content']['data'] - for i, d in enumerate(pdata): - d['name'] = opts['legend'][i] - return p - - -def window(args): - """ Build a window dict structure for sending to client """ - uid = args.get('win', 'window_' + get_rand_id()) - if uid is None: - uid = 'window_' + get_rand_id() - opts = args.get('opts', {}) - - ptype = args['data'][0]['type'] - - p = { - 'command': 'window', - 'id': str(uid), - 'title': opts.get('title', ''), - 'inflate': opts.get('inflate', True), - 'width': opts.get('width'), - 'height': opts.get('height'), - 'contentID': get_rand_id(), # to detected updated windows - } - - if ptype == 'image_history': - p.update({ - 'content': [args['data'][0]['content']], - 'selected': 0, - 'type': ptype, - 'show_slider': opts.get('show_slider', True) - }) - elif ptype in ['image', 'text', 'properties']: - p.update({'content': args['data'][0]['content'], 'type': ptype}) - elif ptype in ['embeddings']: - p.update({ - 'content': args['data'][0]['content'], - 'type': ptype, - 'old_content': [], # Used to cache previous to prevent recompute - }) - p['content']['has_previous'] = False - else: - p['content'] = {'data': args['data'], 'layout': args['layout']} - p['type'] = 'plot' - - return p - - -def broadcast(self, msg, eid): - for s in self.subs: - if type(self.subs[s].eid) is list: - if eid in self.subs[s].eid: - self.subs[s].write_message(msg) - else: - if self.subs[s].eid == eid: - self.subs[s].write_message(msg) - - -def register_window(self, p, eid): - # in case env doesn't exist - is_new_env = False - if eid not in self.state: - is_new_env = True - self.state[eid] = {'jsons': {}, 'reload': {}} - - env = self.state[eid]['jsons'] - - if p['id'] in env: - p['i'] = env[p['id']]['i'] - else: - p['i'] = len(env) - - env[p['id']] = p - - broadcast(self, p, eid) - if is_new_env: - broadcast_envs(self) - self.write(p['id']) - - class PostHandler(BaseHandler): def initialize(self, app): self.state = app.state @@ -864,39 +473,6 @@ def post(self): self.wrap_func(self, args) -def order_by_key(kv): - key, val = kv - return key - - -# Based on json-stable-stringify-python from @haochi with some usecase modifications -def recursive_order(node): - if isinstance(node, Mapping): - ordered_mapping = OrderedDict(sorted(node.items(), key=order_by_key)) - for key, value in ordered_mapping.items(): - ordered_mapping[key] = recursive_order(value) - return ordered_mapping - elif isinstance(node, Sequence): - if isinstance(node, (bytes,)): - return node - elif isinstance(node, (str,)): - return node - else: - return [recursive_order(item) for item in node] - if isinstance(node, float) and node.is_integer(): - return int(node) - return node - - -def stringify(node): - return json.dumps(recursive_order(node), separators=COMPACT_SEPARATORS) - - -def hash_md_window(window_json): - json_string = stringify(window_json).encode("utf-8") - return hashlib.md5(json_string).hexdigest() - - class UpdateHandler(BaseHandler): def initialize(self, app): self.state = app.state @@ -1289,158 +865,6 @@ def post(self): self.wrap_func(self, args) -def load_env(state, eid, socket, env_path=DEFAULT_ENV_PATH): - """ load an environment to a client by socket """ - env = {} - if eid in state: - env = state.get(eid) - elif env_path is not None: - p = os.path.join(env_path, eid.strip(), '.json') - if os.path.exists(p): - with open(p, 'r') as fn: - env = tornado.escape.json_decode(fn.read()) - state[eid] = env - - if 'reload' in env: - socket.write_message( - json.dumps({'command': 'reload', 'data': env['reload']}) - ) - - jsons = list(env.get('jsons', {}).values()) - windows = sorted(jsons, key=lambda k: ('i' not in k, k.get('i', None))) - for v in windows: - socket.write_message(v) - - socket.write_message(json.dumps({'command': 'layout'})) - socket.eid = eid - - -def gather_envs(state, env_path=DEFAULT_ENV_PATH): - if env_path is not None: - items = [i.replace('.json', '') for i in os.listdir(env_path) - if '.json' in i] - else: - items = [] - return sorted(list(set(items + list(state.keys())))) - - -def compare_envs(state, eids, socket, env_path=DEFAULT_ENV_PATH): - logging.info('comparing envs') - eidNums = {e: str(i) for i, e in enumerate(eids)} - env = {} - envs = {} - for eid in eids: - if eid in state: - envs[eid] = state.get(eid) - elif env_path is not None: - p = os.path.join(env_path, eid.strip(), '.json') - if os.path.exists(p): - with open(p, 'r') as fn: - env = tornado.escape.json_decode(fn.read()) - state[eid] = env - envs[eid] = env - - res = copy.deepcopy(envs[list(envs.keys())[0]]) - name2Wid = {res['jsons'][wid].get('title', None): wid + '_compare' - for wid in res.get('jsons', {}) - if 'title' in res['jsons'][wid]} - for wid in list(res['jsons'].keys()): - res['jsons'][wid + '_compare'] = res['jsons'][wid] - res['jsons'][wid] = None - res['jsons'].pop(wid) - - for ix, eid in enumerate(envs.keys()): - env = envs[eid] - for wid in env.get('jsons', {}).keys(): - win = env['jsons'][wid] - if win.get('type', None) != 'plot': - continue - if 'content' not in win: - continue - if 'title' not in win: - continue - title = win['title'] - if title not in name2Wid or title == '': - continue - - destWid = name2Wid[title] - destWidJson = res['jsons'][destWid] - # Combine plots with the same window title. If plot data source was - # labeled "name" in the legend, rename to "envId_legend" where - # envId is enumeration of the selected environments (not the long - # environment id string). This makes plot lines more readable. - if ix == 0: - if 'name' not in destWidJson['content']['data'][0]: - continue # Skip windows with unnamed data - destWidJson['has_compare'] = False - destWidJson['content']['layout']['showlegend'] = True - destWidJson['contentID'] = get_rand_id() - for dataIdx, data in enumerate(destWidJson['content']['data']): - if 'name' not in data: - break # stop working with this plot, not right format - destWidJson['content']['data'][dataIdx]['name'] = \ - '{}_{}'.format(eidNums[eid], data['name']) - else: - if 'name' not in destWidJson['content']['data'][0]: - continue # Skip windows with unnamed data - # has_compare will be set to True only if the window title is - # shared by at least 2 envs. - destWidJson['has_compare'] = True - for _dataIdx, data in enumerate(win['content']['data']): - data = copy.deepcopy(data) - if 'name' not in data: - destWidJson['has_compare'] = False - break # stop working with this plot, not right format - data['name'] = '{}_{}'.format(eidNums[eid], data['name']) - destWidJson['content']['data'].append(data) - - # Make sure that only plots that are shared by at least two envs are shown. - # Check has_compare flag - for destWid in list(res['jsons'].keys()): - if ('has_compare' not in res['jsons'][destWid]) or \ - (not res['jsons'][destWid]['has_compare']): - del res['jsons'][destWid] - - # create legend mapping environment names to environment numbers so one can - # look it up for the new legend - tableRows = [" {} {} ".format(v, eidNums[v]) - for v in eidNums] - - tbl = """" - {}
""".format(' '.join(tableRows)) - - res['jsons']['window_compare_legend'] = { - "command": "window", - "id": "window_compare_legend", - "title": "compare_legend", - "inflate": True, - "width": None, - "height": None, - "contentID": "compare_legend", - "content": tbl, - "type": "text", - "layout": {"title": "compare_legend"}, - "i": 1, - "has_compare": True, - } - if 'reload' in res: - socket.write_message( - json.dumps({'command': 'reload', 'data': res['reload']}) - ) - - jsons = list(res.get('jsons', {}).values()) - windows = sorted(jsons, key=lambda k: ('i' not in k, k.get('i', None))) - for v in windows: - socket.write_message(v) - - socket.write_message(json.dumps({'command': 'layout'})) - socket.eid = eids - - class EnvHandler(BaseHandler): def initialize(self, app): self.state = app.state @@ -1630,261 +1054,3 @@ class ErrorHandler(BaseHandler): def get(self, text): error_text = text or "test error" raise Exception(error_text) - - -# function that downloads and installs javascript, css, and font dependencies: -def download_scripts(proxies=None, install_dir=None): - import visdom - print("Checking for scripts.") - - # location in which to download stuff: - if install_dir is None: - install_dir = os.path.dirname(visdom.__file__) - - # all files that need to be downloaded: - b = 'https://unpkg.com/' - bb = '%sbootstrap@3.3.7/dist/' % b - ext_files = { - # - js - '%sjquery@3.1.1/dist/jquery.min.js' % b: 'jquery.min.js', - '%sbootstrap@3.3.7/dist/js/bootstrap.min.js' % b: 'bootstrap.min.js', - '%sreact@16.2.0/umd/react.production.min.js' % b: 'react-react.min.js', - '%sreact-dom@16.2.0/umd/react-dom.production.min.js' % b: - 'react-dom.min.js', - '%sreact-modal@3.1.10/dist/react-modal.min.js' % b: - 'react-modal.min.js', - 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_SVG': # noqa - 'mathjax-MathJax.js', - # here is another url in case the cdn breaks down again. - # https://raw.githubusercontent.com/plotly/plotly.js/master/dist/plotly.min.js - 'https://cdn.plot.ly/plotly-latest.min.js': 'plotly-plotly.min.js', - # Stanford Javascript Crypto Library for Password Hashing - '%ssjcl@1.0.7/sjcl.js' % b: 'sjcl.js', - - # - css - '%sreact-resizable@1.4.6/css/styles.css' % b: - 'react-resizable-styles.css', - '%sreact-grid-layout@0.16.3/css/styles.css' % b: - 'react-grid-layout-styles.css', - '%scss/bootstrap.min.css' % bb: 'bootstrap.min.css', - - # - fonts - '%sclassnames@2.2.5' % b: 'classnames', - '%slayout-bin-packer@1.4.0/dist/layout-bin-packer.js' % b: - 'layout_bin_packer.js', - '%sfonts/glyphicons-halflings-regular.eot' % bb: - 'glyphicons-halflings-regular.eot', - '%sfonts/glyphicons-halflings-regular.woff2' % bb: - 'glyphicons-halflings-regular.woff2', - '%sfonts/glyphicons-halflings-regular.woff' % bb: - 'glyphicons-halflings-regular.woff', - '%sfonts/glyphicons-halflings-regular.ttf' % bb: - 'glyphicons-halflings-regular.ttf', - '%sfonts/glyphicons-halflings-regular.svg#glyphicons_halflingsregular' % bb: # noqa - 'glyphicons-halflings-regular.svg#glyphicons_halflingsregular', - } - - # make sure all relevant folders exist: - dir_list = [ - '%s' % install_dir, - '%s/static' % install_dir, - '%s/static/js' % install_dir, - '%s/static/css' % install_dir, - '%s/static/fonts' % install_dir, - ] - for directory in dir_list: - if not os.path.exists(directory): - os.makedirs(directory) - - # set up proxy handler: - from urllib import request - from urllib.error import HTTPError, URLError - handler = request.ProxyHandler(proxies) if proxies is not None \ - else request.BaseHandler() - opener = request.build_opener(handler) - request.install_opener(opener) - - built_path = os.path.join(here, 'static/version.built') - is_built = visdom.__version__ == 'no_version_file' - if os.path.exists(built_path): - with open(built_path, 'r') as build_file: - build_version = build_file.read().strip() - if build_version == visdom.__version__: - is_built = True - else: - os.remove(built_path) - if not is_built: - print('Downloading scripts, this may take a little while') - - # download files one-by-one: - for (key, val) in ext_files.items(): - - # set subdirectory: - if val.endswith('.js'): - sub_dir = 'js' - elif val.endswith('.css'): - sub_dir = 'css' - else: - sub_dir = 'fonts' - - # download file: - filename = '%s/static/%s/%s' % (install_dir, sub_dir, val) - if not os.path.exists(filename) or not is_built: - req = request.Request(key, - headers={'User-Agent': 'Chrome/30.0.0.0'}) - try: - data = opener.open(req).read() - with open(filename, 'wb') as fwrite: - fwrite.write(data) - except HTTPError as exc: - logging.error('Error {} while downloading {}'.format( - exc.code, key)) - except URLError as exc: - logging.error('Error {} while downloading {}'.format( - exc.reason, key)) - - if not is_built: - with open(built_path, 'w+') as build_file: - build_file.write(visdom.__version__) - - -def start_server(port=DEFAULT_PORT, hostname=DEFAULT_HOSTNAME, - base_url=DEFAULT_BASE_URL, env_path=DEFAULT_ENV_PATH, - readonly=False, print_func=None, user_credential=None, - use_frontend_client_polling=False): - print("It's Alive!") - app = Application(port=port, base_url=base_url, env_path=env_path, - readonly=readonly, user_credential=user_credential, - use_frontend_client_polling=use_frontend_client_polling) - app.listen(port, max_buffer_size=1024 ** 3) - logging.info("Application Started") - - if "HOSTNAME" in os.environ and hostname == DEFAULT_HOSTNAME: - hostname = os.environ["HOSTNAME"] - else: - hostname = hostname - if print_func is None: - print( - "You can navigate to http://%s:%s%s" % (hostname, port, base_url)) - else: - print_func(port) - ioloop.IOLoop.instance().start() - app.subs = [] - app.sources = [] - - -def main(print_func=None): - parser = argparse.ArgumentParser(description='Start the visdom server.') - parser.add_argument('-port', metavar='port', type=int, - default=DEFAULT_PORT, - help='port to run the server on.') - parser.add_argument('--hostname', metavar='hostname', type=str, - default=DEFAULT_HOSTNAME, - help='host to run the server on.') - parser.add_argument('-base_url', metavar='base_url', type=str, - default=DEFAULT_BASE_URL, - help='base url for server (default = /).') - parser.add_argument('-env_path', metavar='env_path', type=str, - default=DEFAULT_ENV_PATH, - help='path to serialized session to reload.') - parser.add_argument('-logging_level', metavar='logger_level', - default='INFO', - help='logging level (default = INFO). Can take ' - 'logging level name or int (example: 20)') - parser.add_argument('-readonly', help='start in readonly mode', - action='store_true') - parser.add_argument('-enable_login', default=False, action='store_true', - help='start the server with authentication') - parser.add_argument('-force_new_cookie', default=False, - action='store_true', - help='start the server with the new cookie, ' - 'available when -enable_login provided') - parser.add_argument('-use_frontend_client_polling', default=False, - action='store_true', - help='Have the frontend communicate via polling ' - 'rather than over websockets.') - FLAGS = parser.parse_args() - - # Process base_url - base_url = FLAGS.base_url if FLAGS.base_url != DEFAULT_BASE_URL else "" - assert base_url == '' or base_url.startswith('/'), \ - 'base_url should start with /' - assert base_url == '' or not base_url.endswith('/'), \ - 'base_url should not end with / as it is appended automatically' - - try: - logging_level = int(FLAGS.logging_level) - except (ValueError,): - try: - logging_level = logging._checkLevel(FLAGS.logging_level) - except ValueError: - raise KeyError( - "Invalid logging level : {0}".format(FLAGS.logging_level) - ) - - logging.getLogger().setLevel(logging_level) - - if FLAGS.enable_login: - enable_env_login = 'VISDOM_USE_ENV_CREDENTIALS' - use_env = os.environ.get(enable_env_login, False) - if use_env: - username_var = 'VISDOM_USERNAME' - password_var = 'VISDOM_PASSWORD' - username = os.environ.get(username_var) - password = os.environ.get(password_var) - if not (username and password): - print( - '*** Warning ***\n' - 'You have set the {0} env variable but probably ' - 'forgot to setup one (or both) {{ {1}, {2} }} ' - 'variables.\nYou should setup these variables with ' - 'proper username and password to enable logging. Try to ' - 'setup the variables, or unset {0} to input credentials ' - 'via command line prompt instead.\n' - .format(enable_env_login, username_var, password_var)) - sys.exit(1) - - else: - username = input("Please input your username: ") - password = getpass.getpass(prompt="Please input your password: ") - - user_credential = { - "username": username, - "password": hash_password(hash_password(password)) - } - - need_to_set_cookie = ( - not os.path.isfile(DEFAULT_ENV_PATH + "COOKIE_SECRET") - or FLAGS.force_new_cookie) - - if need_to_set_cookie: - if use_env: - cookie_var = 'VISDOM_COOKIE' - env_cookie = os.environ.get(cookie_var) - if env_cookie is None: - print( - 'The cookie file is not found. Please setup {0} env ' - 'variable to provide a cookie value, or unset {1} env ' - 'variable to input credentials and cookie via command ' - 'line prompt.'.format(cookie_var, enable_env_login)) - sys.exit(1) - else: - env_cookie = None - set_cookie(env_cookie) - - else: - user_credential = None - - start_server(port=FLAGS.port, hostname=FLAGS.hostname, base_url=base_url, - env_path=FLAGS.env_path, readonly=FLAGS.readonly, - print_func=print_func, user_credential=user_credential, - use_frontend_client_polling=FLAGS.use_frontend_client_polling) - - -def download_scripts_and_run(): - download_scripts() - main() - - -if __name__ == "__main__": - download_scripts_and_run() diff --git a/py/visdom/server/handlers/base_handlers.py b/py/visdom/server/handlers/base_handlers.py new file mode 100644 index 00000000..e6d3f038 --- /dev/null +++ b/py/visdom/server/handlers/base_handlers.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Server""" + +from visdom.utils.shared_utils import ( + warn_once, + get_rand_id, + get_new_window_id, + ensure_dir_exists, +) +import argparse +import copy +import getpass +import hashlib +import inspect +import json +import jsonpatch +import logging +import math +import os +import sys +import time +import traceback +from collections import OrderedDict +try: + # for after python 3.8 + from collections.abc import Mapping, Sequence +except ImportError: + # for python 3.7 and below + from collections import Mapping, Sequence + +# from zmq.eventloop import ioloop +# ioloop.install() # Needs to happen before any tornado imports! + +import tornado.ioloop # noqa E402: gotta install ioloop first +import tornado.web # noqa E402: gotta install ioloop first +import tornado.websocket # noqa E402: gotta install ioloop first +import tornado.escape # noqa E402: gotta install ioloop first + +LAYOUT_FILE = 'layouts.json' + +COMPACT_SEPARATORS = (',', ':') + +MAX_SOCKET_WAIT = 15 + +class BaseWebSocketHandler(tornado.websocket.WebSocketHandler): + def get_current_user(self): + """ + This method determines the self.current_user + based the value of cookies that set in POST method + at IndexHandler by self.set_secure_cookie + """ + try: + return self.get_secure_cookie("user_password") + except Exception: # Not using secure cookies + return None + + +class BaseHandler(tornado.web.RequestHandler): + def __init__(self, *request, **kwargs): + self.include_host = False + super(BaseHandler, self).__init__(*request, **kwargs) + + def get_current_user(self): + """ + This method determines the self.current_user + based the value of cookies that set in POST method + at IndexHandler by self.set_secure_cookie + """ + try: + return self.get_secure_cookie("user_password") + except Exception: # Not using secure cookies + return None + + def write_error(self, status_code, **kwargs): + logging.error("ERROR: %s: %s" % (status_code, kwargs)) + if "exc_info" in kwargs: + logging.info('Traceback: {}'.format( + traceback.format_exception(*kwargs["exc_info"]))) + if self.settings.get("debug") and "exc_info" in kwargs: + logging.error("rendering error page") + exc_info = kwargs["exc_info"] + # exc_info is a tuple consisting of: + # 1. The class of the Exception + # 2. The actual Exception that was thrown + # 3. The traceback opbject + try: + params = { + 'error': exc_info[1], + 'trace_info': traceback.format_exception(*exc_info), + 'request': self.request.__dict__ + } + + self.render("error.html", **params) + logging.error("rendering complete") + except Exception as e: + logging.error(e) diff --git a/py/visdom/server/run_server.py b/py/visdom/server/run_server.py new file mode 100644 index 00000000..0f6e90c3 --- /dev/null +++ b/py/visdom/server/run_server.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 + +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from visdom.server.app import Application +from visdom.server.defaults import ( + DEFAULT_BASE_URL, + DEFAULT_ENV_PATH, + DEFAULT_HOSTNAME, + DEFAULT_PORT, +) +from visdom.server.build import download_scripts + +import argparse +import getpass +import logging +import os +import sys + +from tornado import ioloop + + +def start_server(port=DEFAULT_PORT, hostname=DEFAULT_HOSTNAME, + base_url=DEFAULT_BASE_URL, env_path=DEFAULT_ENV_PATH, + readonly=False, print_func=None, user_credential=None, + use_frontend_client_polling=False): + """Run a visdom server with the given arguments""" + logging.info("It's Alive!") + app = Application(port=port, base_url=base_url, env_path=env_path, + readonly=readonly, user_credential=user_credential, + use_frontend_client_polling=use_frontend_client_polling) + app.listen(port, max_buffer_size=1024 ** 3) + logging.info("Application Started") + + if "HOSTNAME" in os.environ and hostname == DEFAULT_HOSTNAME: + hostname = os.environ["HOSTNAME"] + else: + hostname = hostname + if print_func is None: + print( + "You can navigate to http://%s:%s%s" % (hostname, port, base_url)) + else: + print_func(port) + ioloop.IOLoop.instance().start() + app.subs = [] + app.sources = [] + + +def main(print_func=None): + """ + Run a server from the command line, first parsing arguments from the + command line + """ + parser = argparse.ArgumentParser(description='Start the visdom server.') + parser.add_argument('-port', metavar='port', type=int, + default=DEFAULT_PORT, + help='port to run the server on.') + parser.add_argument('--hostname', metavar='hostname', type=str, + default=DEFAULT_HOSTNAME, + help='host to run the server on.') + parser.add_argument('-base_url', metavar='base_url', type=str, + default=DEFAULT_BASE_URL, + help='base url for server (default = /).') + parser.add_argument('-env_path', metavar='env_path', type=str, + default=DEFAULT_ENV_PATH, + help='path to serialized session to reload.') + parser.add_argument('-logging_level', metavar='logger_level', + default='INFO', + help='logging level (default = INFO). Can take ' + 'logging level name or int (example: 20)') + parser.add_argument('-readonly', help='start in readonly mode', + action='store_true') + parser.add_argument('-enable_login', default=False, action='store_true', + help='start the server with authentication') + parser.add_argument('-force_new_cookie', default=False, + action='store_true', + help='start the server with the new cookie, ' + 'available when -enable_login provided') + parser.add_argument('-use_frontend_client_polling', default=False, + action='store_true', + help='Have the frontend communicate via polling ' + 'rather than over websockets.') + FLAGS = parser.parse_args() + + # Process base_url + base_url = FLAGS.base_url if FLAGS.base_url != DEFAULT_BASE_URL else "" + assert base_url == '' or base_url.startswith('/'), \ + 'base_url should start with /' + assert base_url == '' or not base_url.endswith('/'), \ + 'base_url should not end with / as it is appended automatically' + + try: + logging_level = int(FLAGS.logging_level) + except (ValueError,): + try: + logging_level = logging._checkLevel(FLAGS.logging_level) + except ValueError: + raise KeyError( + "Invalid logging level : {0}".format(FLAGS.logging_level) + ) + + logging.getLogger().setLevel(logging_level) + + if FLAGS.enable_login: + enable_env_login = 'VISDOM_USE_ENV_CREDENTIALS' + use_env = os.environ.get(enable_env_login, False) + if use_env: + username_var = 'VISDOM_USERNAME' + password_var = 'VISDOM_PASSWORD' + username = os.environ.get(username_var) + password = os.environ.get(password_var) + if not (username and password): + print( + '*** Warning ***\n' + 'You have set the {0} env variable but probably ' + 'forgot to setup one (or both) {{ {1}, {2} }} ' + 'variables.\nYou should setup these variables with ' + 'proper username and password to enable logging. Try to ' + 'setup the variables, or unset {0} to input credentials ' + 'via command line prompt instead.\n' + .format(enable_env_login, username_var, password_var)) + sys.exit(1) + + else: + username = input("Please input your username: ") + password = getpass.getpass(prompt="Please input your password: ") + + user_credential = { + "username": username, + "password": hash_password(hash_password(password)) + } + + need_to_set_cookie = ( + not os.path.isfile(DEFAULT_ENV_PATH + "COOKIE_SECRET") + or FLAGS.force_new_cookie) + + if need_to_set_cookie: + if use_env: + cookie_var = 'VISDOM_COOKIE' + env_cookie = os.environ.get(cookie_var) + if env_cookie is None: + print( + 'The cookie file is not found. Please setup {0} env ' + 'variable to provide a cookie value, or unset {1} env ' + 'variable to input credentials and cookie via command ' + 'line prompt.'.format(cookie_var, enable_env_login)) + sys.exit(1) + else: + env_cookie = None + set_cookie(env_cookie) + + else: + user_credential = None + + start_server(port=FLAGS.port, hostname=FLAGS.hostname, base_url=base_url, + env_path=FLAGS.env_path, readonly=FLAGS.readonly, + print_func=print_func, user_credential=user_credential, + use_frontend_client_polling=FLAGS.use_frontend_client_polling) + + +def download_scripts_and_run(): + download_scripts() + main() + + +if __name__ == "__main__": + download_scripts_and_run() diff --git a/py/visdom/utils/server_utils.py b/py/visdom/utils/server_utils.py new file mode 100644 index 00000000..bd16d6a6 --- /dev/null +++ b/py/visdom/utils/server_utils.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python3 + +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for the server architecture that don't really have +a more appropriate place. + +At the moment, this just inherited all of the floating functions +in the previous server.py class. +""" + +from visdom.server.defaults import ( + DEFAULT_BASE_URL, + DEFAULT_ENV_PATH, + DEFAULT_HOSTNAME, + DEFAULT_PORT, +) +from visdom.utils.shared_utils import get_new_window_id +from visdom.utils.shared_utils import ( + warn_once, + get_rand_id, + get_new_window_id, + ensure_dir_exists, +) +import copy +import hashlib +import json +import logging +import os +import time +from collections import OrderedDict +try: + # for after python 3.8 + from collections.abc import Mapping, Sequence +except ImportError: + # for python 3.7 and below + from collections import Mapping, Sequence + +from zmq.eventloop import ioloop +ioloop.install() # Needs to happen before any tornado imports! + +import tornado.escape # noqa E402: gotta install ioloop first + +LAYOUT_FILE = 'layouts.json' + +here = os.path.abspath(os.path.dirname(__file__)) +COMPACT_SEPARATORS = (',', ':') + +MAX_SOCKET_WAIT = 15 + +# ---- Vaguely server-security related functions ---- # + +def check_auth(f): + """ + Wrapper for server access methods to ensure that the access + is authorized. + """ + def _check_auth(app, *args, **kwargs): + app.last_access = time.time() + if app.login_enabled and not app.current_user: + app.set_status(400) + return + f(app, *args, **kwargs) + return _check_auth + +def set_cookie(value=None): + """Create cookie secret key for authentication""" + if value is not None: + cookie_secret = value + else: + cookie_secret = input("Please input your cookie secret key here: ") + with open(DEFAULT_ENV_PATH + "COOKIE_SECRET", "w") as cookie_file: + cookie_file.write(cookie_secret) + +def hash_password(password): + """Hashing Password with SHA-256""" + return hashlib.sha256(password.encode("utf-8")).hexdigest() + + +# ------- File management helprs ----- # + +def serialize_env(state, eids, env_path=DEFAULT_ENV_PATH): + env_ids = [i for i in eids if i in state] + if env_path is not None: + for env_id in env_ids: + env_path_file = os.path.join(env_path, "{0}.json".format(env_id)) + with open(env_path_file, 'w') as fn: + fn.write(json.dumps(state[env_id])) + return env_ids + + +def serialize_all(state, env_path=DEFAULT_ENV_PATH): + serialize_env(state, list(state.keys()), env_path=env_path) + + +# ------- Environment management helpers ----- # + + +def escape_eid(eid): + """Replace slashes with underscores, to avoid recognizing them + as directories. + """ + return eid.replace('/', '_') + + +def extract_eid(args): + """Extract eid from args. If eid does not exist in args, + it returns 'main'.""" + eid = 'main' if args.get('eid') is None else args.get('eid') + return escape_eid(eid) + + +def update_window(p, args): + """Adds new args to a window if they exist""" + content = p['content'] + layout_update = args.get('layout', {}) + for layout_name, layout_val in layout_update.items(): + if layout_val is not None: + content['layout'][layout_name] = layout_val + opts = args.get('opts', {}) + for opt_name, opt_val in opts.items(): + if opt_val is not None: + p[opt_name] = opt_val + + if 'legend' in opts: + pdata = p['content']['data'] + for i, d in enumerate(pdata): + d['name'] = opts['legend'][i] + return p + + +def window(args): + """ Build a window dict structure for sending to client """ + uid = args.get('win', get_new_window_id()) + if uid is None: + uid = get_new_window_id() + opts = args.get('opts', {}) + + ptype = args['data'][0]['type'] + + p = { + 'command': 'window', + 'id': str(uid), + 'title': opts.get('title', ''), + 'inflate': opts.get('inflate', True), + 'width': opts.get('width'), + 'height': opts.get('height'), + 'contentID': get_rand_id(), # to detected updated windows + } + + if ptype == 'image_history': + p.update({ + 'content': [args['data'][0]['content']], + 'selected': 0, + 'type': ptype, + 'show_slider': opts.get('show_slider', True) + }) + elif ptype in ['image', 'text', 'properties']: + p.update({'content': args['data'][0]['content'], 'type': ptype}) + elif ptype in ['embeddings']: + p.update({ + 'content': args['data'][0]['content'], + 'type': ptype, + 'old_content': [], # Used to cache previous to prevent recompute + }) + p['content']['has_previous'] = False + else: + p['content'] = {'data': args['data'], 'layout': args['layout']} + p['type'] = 'plot' + + return p + + +def gather_envs(state, env_path=DEFAULT_ENV_PATH): + if env_path is not None: + items = [i.replace('.json', '') for i in os.listdir(env_path) + if '.json' in i] + else: + items = [] + return sorted(list(set(items + list(state.keys())))) + + +def compare_envs(state, eids, socket, env_path=DEFAULT_ENV_PATH): + logging.info('comparing envs') + eidNums = {e: str(i) for i, e in enumerate(eids)} + env = {} + envs = {} + for eid in eids: + if eid in state: + envs[eid] = state.get(eid) + elif env_path is not None: + p = os.path.join(env_path, eid.strip(), '.json') + if os.path.exists(p): + with open(p, 'r') as fn: + env = tornado.escape.json_decode(fn.read()) + state[eid] = env + envs[eid] = env + + res = copy.deepcopy(envs[list(envs.keys())[0]]) + name2Wid = {res['jsons'][wid].get('title', None): wid + '_compare' + for wid in res.get('jsons', {}) + if 'title' in res['jsons'][wid]} + for wid in list(res['jsons'].keys()): + res['jsons'][wid + '_compare'] = res['jsons'][wid] + res['jsons'][wid] = None + res['jsons'].pop(wid) + + for ix, eid in enumerate(envs.keys()): + env = envs[eid] + for wid in env.get('jsons', {}).keys(): + win = env['jsons'][wid] + if win.get('type', None) != 'plot': + continue + if 'content' not in win: + continue + if 'title' not in win: + continue + title = win['title'] + if title not in name2Wid or title == '': + continue + + destWid = name2Wid[title] + destWidJson = res['jsons'][destWid] + # Combine plots with the same window title. If plot data source was + # labeled "name" in the legend, rename to "envId_legend" where + # envId is enumeration of the selected environments (not the long + # environment id string). This makes plot lines more readable. + if ix == 0: + if 'name' not in destWidJson['content']['data'][0]: + continue # Skip windows with unnamed data + destWidJson['has_compare'] = False + destWidJson['content']['layout']['showlegend'] = True + destWidJson['contentID'] = get_rand_id() + for dataIdx, data in enumerate(destWidJson['content']['data']): + if 'name' not in data: + break # stop working with this plot, not right format + destWidJson['content']['data'][dataIdx]['name'] = \ + '{}_{}'.format(eidNums[eid], data['name']) + else: + if 'name' not in destWidJson['content']['data'][0]: + continue # Skip windows with unnamed data + # has_compare will be set to True only if the window title is + # shared by at least 2 envs. + destWidJson['has_compare'] = True + for _dataIdx, data in enumerate(win['content']['data']): + data = copy.deepcopy(data) + if 'name' not in data: + destWidJson['has_compare'] = False + break # stop working with this plot, not right format + data['name'] = '{}_{}'.format(eidNums[eid], data['name']) + destWidJson['content']['data'].append(data) + + # Make sure that only plots that are shared by at least two envs are shown. + # Check has_compare flag + for destWid in list(res['jsons'].keys()): + if ('has_compare' not in res['jsons'][destWid]) or \ + (not res['jsons'][destWid]['has_compare']): + del res['jsons'][destWid] + + # create legend mapping environment names to environment numbers so one can + # look it up for the new legend + tableRows = [" {} {} ".format(v, eidNums[v]) + for v in eidNums] + + tbl = """" + {}
""".format(' '.join(tableRows)) + + res['jsons']['window_compare_legend'] = { + "command": "window", + "id": "window_compare_legend", + "title": "compare_legend", + "inflate": True, + "width": None, + "height": None, + "contentID": "compare_legend", + "content": tbl, + "type": "text", + "layout": {"title": "compare_legend"}, + "i": 1, + "has_compare": True, + } + if 'reload' in res: + socket.write_message( + json.dumps({'command': 'reload', 'data': res['reload']}) + ) + + jsons = list(res.get('jsons', {}).values()) + windows = sorted(jsons, key=lambda k: ('i' not in k, k.get('i', None))) + for v in windows: + socket.write_message(v) + + socket.write_message(json.dumps({'command': 'layout'})) + socket.eid = eids + + + +# ------- Broadcasting functions ---------- # + +def broadcast_envs(handler, target_subs=None): + if target_subs is None: + target_subs = handler.subs.values() + for sub in target_subs: + sub.write_message(json.dumps( + {'command': 'env_update', 'data': list(handler.state.keys())} + )) + + +def send_to_sources(handler, msg): + target_sources = handler.sources.values() + for source in target_sources: + source.write_message(json.dumps(msg)) + + +def load_env(state, eid, socket, env_path=DEFAULT_ENV_PATH): + """ load an environment to a client by socket """ + env = {} + if eid in state: + env = state.get(eid) + elif env_path is not None: + p = os.path.join(env_path, eid.strip(), '.json') + if os.path.exists(p): + with open(p, 'r') as fn: + env = tornado.escape.json_decode(fn.read()) + state[eid] = env + + if 'reload' in env: + socket.write_message( + json.dumps({'command': 'reload', 'data': env['reload']}) + ) + + jsons = list(env.get('jsons', {}).values()) + windows = sorted(jsons, key=lambda k: ('i' not in k, k.get('i', None))) + for v in windows: + socket.write_message(v) + + socket.write_message(json.dumps({'command': 'layout'})) + socket.eid = eid + + +def broadcast(self, msg, eid): + for s in self.subs: + if type(self.subs[s].eid) is list: + if eid in self.subs[s].eid: + self.subs[s].write_message(msg) + else: + if self.subs[s].eid == eid: + self.subs[s].write_message(msg) + + +def register_window(self, p, eid): + # in case env doesn't exist + is_new_env = False + if eid not in self.state: + is_new_env = True + self.state[eid] = {'jsons': {}, 'reload': {}} + + env = self.state[eid]['jsons'] + + if p['id'] in env: + p['i'] = env[p['id']]['i'] + else: + p['i'] = len(env) + + env[p['id']] = p + + broadcast(self, p, eid) + if is_new_env: + broadcast_envs(self) + self.write(p['id']) + + +# ----- Json patch helpers ---------- # + + +def order_by_key(kv): + key, val = kv + return key + + +# Based on json-stable-stringify-python from @haochi with some usecase modifications +def recursive_order(node): + if isinstance(node, Mapping): + ordered_mapping = OrderedDict(sorted(node.items(), key=order_by_key)) + for key, value in ordered_mapping.items(): + ordered_mapping[key] = recursive_order(value) + return ordered_mapping + elif isinstance(node, Sequence): + if isinstance(node, (bytes,)): + return node + elif isinstance(node, (str,)): + return node + else: + return [recursive_order(item) for item in node] + if isinstance(node, float) and node.is_integer(): + return int(node) + return node + + +def stringify(node): + return json.dumps(recursive_order(node), separators=COMPACT_SEPARATORS) + + +def hash_md_window(window_json): + json_string = stringify(window_json).encode("utf-8") + return hashlib.md5(json_string).hexdigest() diff --git a/py/visdom/utils/shared_utils.py b/py/visdom/utils/shared_utils.py new file mode 100644 index 00000000..23c164d4 --- /dev/null +++ b/py/visdom/utils/shared_utils.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities that could be potentially useful in various different +parts of the visdom stack. Not to be used for particularly specific +helper functions. +""" + +import inspect +import uuid +import warnings +import os + +_seen_warnings = set() + + +def warn_once(msg, warningtype=None): + """ + Raise a warning, but only once. + :param str msg: Message to display + :param Warning warningtype: Type of warning, e.g. DeprecationWarning + """ + global _seen_warnings + if msg not in _seen_warnings: + _seen_warnings.add(msg) + warnings.warn(msg, warningtype, stacklevel=2) + + +def get_rand_id(): + """Returns a random id string""" + return str(uuid.uuid4()) + + +def get_new_window_id(): + """Return a string to be used for a new window""" + return f'win_{get_rand_id()}' + + +def ensure_dir_exists(path): + """Make sure the parent dir exists for path so we can write a file.""" + try: + os.makedirs(os.path.dirname(path)) + except OSError as e1: + assert e1.errno == 17 # errno.EEXIST + + +def get_visdom_path(): + """Get the path to the visdom/py/visdom directory.""" + cwd = os.path.dirname( + os.path.abspath(inspect.getfile(inspect.currentframe()))) + return os.path.dirname(cwd) + + +def get_visdom_path_to(filename): + """Get the path to a file in the visdom/py/visdom directory.""" + return os.path.join(get_visdom_path(), filename) diff --git a/setup.py b/setup.py index 73c79470..b5d0e97c 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ def get_dist(pkgname): 'scipy', 'requests', 'tornado', - 'pyzmq', + # 'pyzmq', 'six', 'jsonpatch', 'websocket-client', @@ -57,6 +57,7 @@ def get_dist(pkgname): url='https://github.com/facebookresearch/visdom', description='A tool for visualizing live, rich data for Torch and Numpy', long_description=readme, + long_description_content_type="text/markdown", license='CC-BY-NC-4.0', # Package info From fa1b3e180aab92eddc2aa4e5f6c3830b3a6992d6 Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Mon, 23 Sep 2019 00:02:11 -0400 Subject: [PATCH 9/9] Splitting all handlers into socket and base handlers --- py/visdom/server/app.py | 3 +- py/visdom/server/handlers/base_handlers.py | 54 +- py/visdom/server/handlers/socket_handlers.py | 515 ++++++++++++++++++ .../{all_handlers.py => web_handlers.py} | 499 +---------------- py/visdom/server/run_server.py | 4 + py/visdom/utils/server_utils.py | 11 +- 6 files changed, 558 insertions(+), 528 deletions(-) create mode 100644 py/visdom/server/handlers/socket_handlers.py rename py/visdom/server/handlers/{all_handlers.py => web_handlers.py} (50%) diff --git a/py/visdom/server/app.py b/py/visdom/server/app.py index 955ab74f..7f64fef3 100644 --- a/py/visdom/server/app.py +++ b/py/visdom/server/app.py @@ -18,7 +18,8 @@ ) # TODO replace this next -from visdom.server.handlers.all_handlers import * +from visdom.server.handlers.socket_handlers import * +from visdom.server.handlers.web_handlers import * import copy import hashlib diff --git a/py/visdom/server/handlers/base_handlers.py b/py/visdom/server/handlers/base_handlers.py index e6d3f038..21607999 100644 --- a/py/visdom/server/handlers/base_handlers.py +++ b/py/visdom/server/handlers/base_handlers.py @@ -6,50 +6,24 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -"""Server""" +""" +Contain the basic web request handlers that all other handlers derive from +""" -from visdom.utils.shared_utils import ( - warn_once, - get_rand_id, - get_new_window_id, - ensure_dir_exists, -) -import argparse -import copy -import getpass -import hashlib -import inspect -import json -import jsonpatch import logging -import math -import os -import sys -import time import traceback -from collections import OrderedDict -try: - # for after python 3.8 - from collections.abc import Mapping, Sequence -except ImportError: - # for python 3.7 and below - from collections import Mapping, Sequence -# from zmq.eventloop import ioloop -# ioloop.install() # Needs to happen before any tornado imports! +import tornado.web +import tornado.websocket -import tornado.ioloop # noqa E402: gotta install ioloop first -import tornado.web # noqa E402: gotta install ioloop first -import tornado.websocket # noqa E402: gotta install ioloop first -import tornado.escape # noqa E402: gotta install ioloop first - -LAYOUT_FILE = 'layouts.json' - -COMPACT_SEPARATORS = (',', ':') - -MAX_SOCKET_WAIT = 15 class BaseWebSocketHandler(tornado.websocket.WebSocketHandler): + """ + Implements any required overriden functionality from the basic tornado + websocket handler. Also contains some shared logic for all WebSocketHandler + classes. + """ + def get_current_user(self): """ This method determines the self.current_user @@ -63,6 +37,11 @@ def get_current_user(self): class BaseHandler(tornado.web.RequestHandler): + """ + Implements any required overriden functionality from the basic tornado + request handlers, and contains any convenient shared logic helpers. + """ + def __init__(self, *request, **kwargs): self.include_host = False super(BaseHandler, self).__init__(*request, **kwargs) @@ -97,6 +76,7 @@ def write_error(self, status_code, **kwargs): 'request': self.request.__dict__ } + # TODO make an error.html page self.render("error.html", **params) logging.error("rendering complete") except Exception as e: diff --git a/py/visdom/server/handlers/socket_handlers.py b/py/visdom/server/handlers/socket_handlers.py new file mode 100644 index 00000000..d3d1c891 --- /dev/null +++ b/py/visdom/server/handlers/socket_handlers.py @@ -0,0 +1,515 @@ +#!/usr/bin/env python3 + +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Handlers for the different types of socket events. Mostly handles parsing and +processing the web events themselves and interfacing with the server as +necessary, but defers underlying manipulations of the server's data to +the data_model itself. +""" + +# TODO fix these imports +from visdom.utils.shared_utils import * +from visdom.utils.server_utils import * +from visdom.server.handlers.base_handlers import * +import copy +import getpass +import hashlib +import json +import jsonpatch +import logging +import math +import os +import time +from collections import OrderedDict +try: + # for after python 3.8 + from collections.abc import Mapping, Sequence +except ImportError: + # for python 3.7 and below + from collections import Mapping, Sequence + +import tornado.ioloop +import tornado.escape + +MAX_SOCKET_WAIT = 15 + +# TODO move the logic that actually parses environments and layouts to +# new classes in the data_model folder. +# TODO move generalized initialization logic from these handlers into the +# basehandler +# TODO abstract out any direct references to the app where possible from +# all handlers. Can instead provide accessor functions on the state? +# TODO abstract socket interaction logic such that both the regular +# sockets and the poll-based wrappers are using as much shared code as +# possible. Try to standardize the code between the client-server and +# visdom-server socket edges. +class VisSocketHandler(BaseWebSocketHandler): + def initialize(self, app): + self.state = app.state + self.subs = app.subs + self.sources = app.sources + self.port = app.port + self.env_path = app.env_path + self.login_enabled = app.login_enabled + + def check_origin(self, origin): + return True + + def open(self): + if self.login_enabled and not self.current_user: + self.close() + return + self.sid = str(hex(int(time.time() * 10000000))[2:]) + if self not in list(self.sources.values()): + self.eid = 'main' + self.sources[self.sid] = self + logging.info('Opened visdom socket from ip: {}'.format( + self.request.remote_ip)) + + self.write_message( + json.dumps({'command': 'alive', 'data': 'vis_alive'})) + + def on_message(self, message): + logging.info('from visdom client: {}'.format(message)) + msg = tornado.escape.json_decode(tornado.escape.to_basestring(message)) + + cmd = msg.get('cmd') + if cmd == 'echo': + for sub in self.sources.values(): + sub.write_message(json.dumps(msg)) + + def on_close(self): + if self in list(self.sources.values()): + self.sources.pop(self.sid, None) + + +class VisSocketWrapper(): + def __init__(self, app): + self.state = app.state + self.subs = app.subs + self.sources = app.sources + self.port = app.port + self.env_path = app.env_path + self.login_enabled = app.login_enabled + self.app = app + self.messages = [] + self.last_read_time = time.time() + self.open() + try: + if not self.app.socket_wrap_monitor.is_running(): + self.app.socket_wrap_monitor.start() + except AttributeError: + self.app.socket_wrap_monitor = tornado.ioloop.PeriodicCallback( + self.socket_wrap_monitor_thread, 15000 + ) + self.app.socket_wrap_monitor.start() + + # TODO refactor the two socket wrappers into a wrapper class + def socket_wrap_monitor_thread(self): + if len(self.subs) > 0 or len(self.sources) > 0: + for sub in list(self.subs.values()): + if time.time() - sub.last_read_time > MAX_SOCKET_WAIT: + sub.close() + for sub in list(self.sources.values()): + if time.time() - sub.last_read_time > MAX_SOCKET_WAIT: + sub.close() + else: + self.app.socket_wrap_monitor.stop() + + def open(self): + if self.login_enabled and not self.current_user: + print("AUTH Failed in SocketHandler") + self.close() + return + self.sid = get_rand_id() + if self not in list(self.sources.values()): + self.eid = 'main' + self.sources[self.sid] = self + logging.info('Mocking visdom socket: {}'.format(self.sid)) + + self.write_message( + json.dumps({'command': 'alive', 'data': 'vis_alive'})) + + def on_message(self, message): + logging.info('from visdom client: {}'.format(message)) + msg = tornado.escape.json_decode(tornado.escape.to_basestring(message)) + + cmd = msg.get('cmd') + if cmd == 'echo': + for sub in self.sources.values(): + sub.write_message(json.dumps(msg)) + + def close(self): + if self in list(self.sources.values()): + self.sources.pop(self.sid, None) + + def write_message(self, msg): + self.messages.append(msg) + + def get_messages(self): + to_send = [] + while len(self.messages) > 0: + message = self.messages.pop() + if type(message) is dict: + # Not all messages are being formatted the same way (JSON) + # TODO investigate + message = json.dumps(message) + to_send.append(message) + self.last_read_time = time.time() + return to_send + + +class SocketHandler(BaseWebSocketHandler): + def initialize(self, app): + self.port = app.port + self.env_path = app.env_path + self.app = app + self.state = app.state + self.subs = app.subs + self.sources = app.sources + self.broadcast_layouts() + self.readonly = app.readonly + self.login_enabled = app.login_enabled + + def check_origin(self, origin): + return True + + def broadcast_layouts(self, target_subs=None): + if target_subs is None: + target_subs = self.subs.values() + for sub in target_subs: + sub.write_message(json.dumps( + {'command': 'layout_update', 'data': self.app.layouts} + )) + + def open(self): + if self.login_enabled and not self.current_user: + print("AUTH Failed in SocketHandler") + self.close() + return + self.sid = get_rand_id() + if self not in list(self.subs.values()): + self.eid = 'main' + self.subs[self.sid] = self + logging.info( + 'Opened new socket from ip: {}'.format(self.request.remote_ip)) + + self.write_message( + json.dumps({'command': 'register', 'data': self.sid, + 'readonly': self.readonly})) + self.broadcast_layouts([self]) + broadcast_envs(self, [self]) + + def on_message(self, message): + logging.info('from web client: {}'.format(message)) + msg = tornado.escape.json_decode(tornado.escape.to_basestring(message)) + + cmd = msg.get('cmd') + + if self.readonly: + return + + if cmd == 'close': + if 'data' in msg and 'eid' in msg: + logging.info('closing window {}'.format(msg['data'])) + p_data = self.state[msg['eid']]['jsons'].pop(msg['data'], None) + event = { + 'event_type': 'close', + 'target': msg['data'], + 'eid': msg['eid'], + 'pane_data': p_data, + } + send_to_sources(self, event) + elif cmd == 'save': + # save localStorage window metadata + if 'data' in msg and 'eid' in msg: + msg['eid'] = escape_eid(msg['eid']) + self.state[msg['eid']] = \ + copy.deepcopy(self.state[msg['prev_eid']]) + self.state[msg['eid']]['reload'] = msg['data'] + self.eid = msg['eid'] + serialize_env(self.state, [self.eid], env_path=self.env_path) + elif cmd == 'delete_env': + if 'eid' in msg: + logging.info('closing environment {}'.format(msg['eid'])) + del self.state[msg['eid']] + if self.env_path is not None: + p = os.path.join( + self.env_path, + "{0}.json".format(msg['eid']) + ) + os.remove(p) + broadcast_envs(self) + elif cmd == 'save_layouts': + if 'data' in msg: + self.app.layouts = msg.get('data') + self.app.save_layouts() + self.broadcast_layouts() + elif cmd == 'forward_to_vis': + packet = msg.get('data') + environment = self.state[packet['eid']] + if packet.get('pane_data') is not False: + packet['pane_data'] = environment['jsons'][packet['target']] + send_to_sources(self, msg.get('data')) + elif cmd == 'layout_item_update': + eid = msg.get('eid') + win = msg.get('win') + self.state[eid]['reload'][win] = msg.get('data') + elif cmd == 'pop_embeddings_pane': + packet = msg.get('data') + eid = packet['eid'] + win = packet['target'] + p = self.state[eid]['jsons'][win] + p['content']['selected'] = None + p['content']['data'] = p['old_content'].pop() + if len(p['old_content']) == 0: + p['content']['has_previous'] = False + p['contentID'] = get_rand_id() + broadcast(self, p, eid) + + def on_close(self): + if self in list(self.subs.values()): + self.subs.pop(self.sid, None) + + +# TODO condense some of the functionality between this class and the +# original SocketHandler class +class ClientSocketWrapper(): + """ + Wraps all of the socket actions in regular request handling, thus + allowing all of the same information to be sent via a polling interface + """ + def __init__(self, app): + self.port = app.port + self.env_path = app.env_path + self.app = app + self.state = app.state + self.subs = app.subs + self.sources = app.sources + self.readonly = app.readonly + self.login_enabled = app.login_enabled + self.messages = [] + self.last_read_time = time.time() + self.open() + try: + if not self.app.socket_wrap_monitor.is_running(): + self.app.socket_wrap_monitor.start() + except AttributeError: + self.app.socket_wrap_monitor = tornado.ioloop.PeriodicCallback( + self.socket_wrap_monitor_thread, 15000 + ) + self.app.socket_wrap_monitor.start() + + def socket_wrap_monitor_thread(self): + # TODO mark wrapped subs and sources separately + if len(self.subs) > 0 or len(self.sources) > 0: + for sub in list(self.subs.values()): + if time.time() - sub.last_read_time > MAX_SOCKET_WAIT: + sub.close() + for sub in list(self.sources.values()): + if time.time() - sub.last_read_time > MAX_SOCKET_WAIT: + sub.close() + else: + self.app.socket_wrap_monitor.stop() + + def broadcast_layouts(self, target_subs=None): + if target_subs is None: + target_subs = self.subs.values() + for sub in target_subs: + sub.write_message(json.dumps( + {'command': 'layout_update', 'data': self.app.layouts} + )) + + def open(self): + if self.login_enabled and not self.current_user: + print("AUTH Failed in SocketHandler") + self.close() + return + self.sid = get_rand_id() + if self not in list(self.subs.values()): + self.eid = 'main' + self.subs[self.sid] = self + logging.info('Mocking new socket: {}'.format(self.sid)) + + self.write_message( + json.dumps({'command': 'register', 'data': self.sid, + 'readonly': self.readonly})) + self.broadcast_layouts([self]) + broadcast_envs(self, [self]) + + def on_message(self, message): + logging.info('from web client: {}'.format(message)) + msg = tornado.escape.json_decode(tornado.escape.to_basestring(message)) + + cmd = msg.get('cmd') + + if self.readonly: + return + + if cmd == 'close': + if 'data' in msg and 'eid' in msg: + logging.info('closing window {}'.format(msg['data'])) + p_data = self.state[msg['eid']]['jsons'].pop(msg['data'], None) + event = { + 'event_type': 'close', + 'target': msg['data'], + 'eid': msg['eid'], + 'pane_data': p_data, + } + send_to_sources(self, event) + elif cmd == 'save': + # save localStorage window metadata + if 'data' in msg and 'eid' in msg: + msg['eid'] = escape_eid(msg['eid']) + self.state[msg['eid']] = \ + copy.deepcopy(self.state[msg['prev_eid']]) + self.state[msg['eid']]['reload'] = msg['data'] + self.eid = msg['eid'] + serialize_env(self.state, [self.eid], env_path=self.env_path) + elif cmd == 'delete_env': + if 'eid' in msg: + logging.info('closing environment {}'.format(msg['eid'])) + del self.state[msg['eid']] + if self.env_path is not None: + p = os.path.join( + self.env_path, + "{0}.json".format(msg['eid']) + ) + os.remove(p) + broadcast_envs(self) + elif cmd == 'save_layouts': + if 'data' in msg: + self.app.layouts = msg.get('data') + self.app.save_layouts() + self.broadcast_layouts() + elif cmd == 'forward_to_vis': + packet = msg.get('data') + environment = self.state[packet['eid']] + packet['pane_data'] = environment['jsons'][packet['target']] + send_to_sources(self, msg.get('data')) + elif cmd == 'layout_item_update': + eid = msg.get('eid') + win = msg.get('win') + self.state[eid]['reload'][win] = msg.get('data') + + def close(self): + if self in list(self.subs.values()): + self.subs.pop(self.sid, None) + + def write_message(self, msg): + self.messages.append(msg) + + def get_messages(self): + to_send = [] + while len(self.messages) > 0: + message = self.messages.pop() + if type(message) is dict: + # Not all messages are being formatted the same way (JSON) + # TODO investigate + message = json.dumps(message) + to_send.append(message) + self.last_read_time = time.time() + return to_send + + +class SocketWrap(BaseHandler): + def initialize(self, app): + self.state = app.state + self.subs = app.subs + self.sources = app.sources + self.port = app.port + self.env_path = app.env_path + self.login_enabled = app.login_enabled + self.app = app + + @check_auth + def post(self): + """Either write a message to the socket, or query what's there""" + # TODO formalize failure reasons + args = tornado.escape.json_decode( + tornado.escape.to_basestring(self.request.body) + ) + type = args.get('message_type') + sid = args.get('sid') + socket_wrap = self.subs.get(sid) + # ensure a wrapper still exists for this connection + if socket_wrap is None: + self.write(json.dumps({'success': False, 'reason': 'closed'})) + return + + # handle the requests + if type == 'query': + messages = socket_wrap.get_messages() + self.write(json.dumps({ + 'success': True, 'messages': messages + })) + elif type == 'send': + msg = args.get('message') + if msg is None: + self.write(json.dumps({'success': False, 'reason': 'no msg'})) + else: + socket_wrap.on_message(msg) + self.write(json.dumps({'success': True})) + else: + self.write(json.dumps({'success': False, 'reason': 'invalid'})) + + @check_auth + def get(self): + """Create a new socket wrapper for this requester, return the id""" + new_sub = ClientSocketWrapper(self.app) + self.write(json.dumps({'success': True, 'sid': new_sub.sid})) + + +# TODO refactor socket wrappers to one class +class VisSocketWrap(BaseHandler): + def initialize(self, app): + self.state = app.state + self.subs = app.subs + self.sources = app.sources + self.port = app.port + self.env_path = app.env_path + self.login_enabled = app.login_enabled + self.app = app + + @check_auth + def post(self): + """Either write a message to the socket, or query what's there""" + # TODO formalize failure reasons + args = tornado.escape.json_decode( + tornado.escape.to_basestring(self.request.body) + ) + type = args.get('message_type') + sid = args.get('sid') + + if sid is None: + new_sub = VisSocketWrapper(self.app) + self.write(json.dumps({'success': True, 'sid': new_sub.sid})) + return + + socket_wrap = self.sources.get(sid) + # ensure a wrapper still exists for this connection + if socket_wrap is None: + self.write(json.dumps({'success': False, 'reason': 'closed'})) + return + + # handle the requests + if type == 'query': + messages = socket_wrap.get_messages() + self.write(json.dumps({ + 'success': True, 'messages': messages + })) + elif type == 'send': + msg = args.get('message') + if msg is None: + self.write(json.dumps({'success': False, 'reason': 'no msg'})) + else: + socket_wrap.on_message(msg) + self.write(json.dumps({'success': True})) + else: + self.write(json.dumps({'success': False, 'reason': 'invalid'})) diff --git a/py/visdom/server/handlers/all_handlers.py b/py/visdom/server/handlers/web_handlers.py similarity index 50% rename from py/visdom/server/handlers/all_handlers.py rename to py/visdom/server/handlers/web_handlers.py index 798a5717..f73a8aca 100644 --- a/py/visdom/server/handlers/all_handlers.py +++ b/py/visdom/server/handlers/web_handlers.py @@ -6,12 +6,17 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -"""Server""" +""" +Handlers for the different types of web request events. Mostly handles parsing +and processing the web events themselves and interfacing with the server as +necessary, but defers underlying manipulations of the server's data to +the data_model itself. +""" # TODO fix these imports from visdom.utils.shared_utils import * from visdom.utils.server_utils import * -from visdom.server.handlers.base_handlers import * +from visdom.server.handlers.base_handlers import BaseHandler import copy import getpass import hashlib @@ -29,389 +34,17 @@ # for python 3.7 and below from collections import Mapping, Sequence -import tornado.ioloop # noqa E402: gotta install ioloop first -import tornado.web # noqa E402: gotta install ioloop first -import tornado.websocket # noqa E402: gotta install ioloop first -import tornado.escape # noqa E402: gotta install ioloop first - -LAYOUT_FILE = 'layouts.json' - -here = os.path.abspath(os.path.dirname(__file__)) -COMPACT_SEPARATORS = (',', ':') +import tornado.escape MAX_SOCKET_WAIT = 15 -# TODO Split this file up it's terrible - -class VisSocketHandler(BaseWebSocketHandler): - def initialize(self, app): - self.state = app.state - self.subs = app.subs - self.sources = app.sources - self.port = app.port - self.env_path = app.env_path - self.login_enabled = app.login_enabled - - def check_origin(self, origin): - return True - - def open(self): - if self.login_enabled and not self.current_user: - self.close() - return - self.sid = str(hex(int(time.time() * 10000000))[2:]) - if self not in list(self.sources.values()): - self.eid = 'main' - self.sources[self.sid] = self - logging.info('Opened visdom socket from ip: {}'.format( - self.request.remote_ip)) - - self.write_message( - json.dumps({'command': 'alive', 'data': 'vis_alive'})) - - def on_message(self, message): - logging.info('from visdom client: {}'.format(message)) - msg = tornado.escape.json_decode(tornado.escape.to_basestring(message)) - - cmd = msg.get('cmd') - if cmd == 'echo': - for sub in self.sources.values(): - sub.write_message(json.dumps(msg)) - - def on_close(self): - if self in list(self.sources.values()): - self.sources.pop(self.sid, None) - - -class VisSocketWrapper(): - def __init__(self, app): - self.state = app.state - self.subs = app.subs - self.sources = app.sources - self.port = app.port - self.env_path = app.env_path - self.login_enabled = app.login_enabled - self.app = app - self.messages = [] - self.last_read_time = time.time() - self.open() - try: - if not self.app.socket_wrap_monitor.is_running(): - self.app.socket_wrap_monitor.start() - except AttributeError: - self.app.socket_wrap_monitor = tornado.ioloop.PeriodicCallback( - self.socket_wrap_monitor_thread, 15000 - ) - self.app.socket_wrap_monitor.start() - - # TODO refactor the two socket wrappers into a wrapper class - def socket_wrap_monitor_thread(self): - if len(self.subs) > 0 or len(self.sources) > 0: - for sub in list(self.subs.values()): - if time.time() - sub.last_read_time > MAX_SOCKET_WAIT: - sub.close() - for sub in list(self.sources.values()): - if time.time() - sub.last_read_time > MAX_SOCKET_WAIT: - sub.close() - else: - self.app.socket_wrap_monitor.stop() - - def open(self): - if self.login_enabled and not self.current_user: - print("AUTH Failed in SocketHandler") - self.close() - return - self.sid = get_rand_id() - if self not in list(self.sources.values()): - self.eid = 'main' - self.sources[self.sid] = self - logging.info('Mocking visdom socket: {}'.format(self.sid)) - - self.write_message( - json.dumps({'command': 'alive', 'data': 'vis_alive'})) - - def on_message(self, message): - logging.info('from visdom client: {}'.format(message)) - msg = tornado.escape.json_decode(tornado.escape.to_basestring(message)) - - cmd = msg.get('cmd') - if cmd == 'echo': - for sub in self.sources.values(): - sub.write_message(json.dumps(msg)) - - def close(self): - if self in list(self.sources.values()): - self.sources.pop(self.sid, None) - - def write_message(self, msg): - self.messages.append(msg) - - def get_messages(self): - to_send = [] - while len(self.messages) > 0: - message = self.messages.pop() - if type(message) is dict: - # Not all messages are being formatted the same way (JSON) - # TODO investigate - message = json.dumps(message) - to_send.append(message) - self.last_read_time = time.time() - return to_send - - -class SocketHandler(BaseWebSocketHandler): - def initialize(self, app): - self.port = app.port - self.env_path = app.env_path - self.app = app - self.state = app.state - self.subs = app.subs - self.sources = app.sources - self.broadcast_layouts() - self.readonly = app.readonly - self.login_enabled = app.login_enabled - - def check_origin(self, origin): - return True - - def broadcast_layouts(self, target_subs=None): - if target_subs is None: - target_subs = self.subs.values() - for sub in target_subs: - sub.write_message(json.dumps( - {'command': 'layout_update', 'data': self.app.layouts} - )) - - def open(self): - if self.login_enabled and not self.current_user: - print("AUTH Failed in SocketHandler") - self.close() - return - self.sid = get_rand_id() - if self not in list(self.subs.values()): - self.eid = 'main' - self.subs[self.sid] = self - logging.info( - 'Opened new socket from ip: {}'.format(self.request.remote_ip)) - - self.write_message( - json.dumps({'command': 'register', 'data': self.sid, - 'readonly': self.readonly})) - self.broadcast_layouts([self]) - broadcast_envs(self, [self]) - - def on_message(self, message): - logging.info('from web client: {}'.format(message)) - msg = tornado.escape.json_decode(tornado.escape.to_basestring(message)) - - cmd = msg.get('cmd') - - if self.readonly: - return - - if cmd == 'close': - if 'data' in msg and 'eid' in msg: - logging.info('closing window {}'.format(msg['data'])) - p_data = self.state[msg['eid']]['jsons'].pop(msg['data'], None) - event = { - 'event_type': 'close', - 'target': msg['data'], - 'eid': msg['eid'], - 'pane_data': p_data, - } - send_to_sources(self, event) - elif cmd == 'save': - # save localStorage window metadata - if 'data' in msg and 'eid' in msg: - msg['eid'] = escape_eid(msg['eid']) - self.state[msg['eid']] = \ - copy.deepcopy(self.state[msg['prev_eid']]) - self.state[msg['eid']]['reload'] = msg['data'] - self.eid = msg['eid'] - serialize_env(self.state, [self.eid], env_path=self.env_path) - elif cmd == 'delete_env': - if 'eid' in msg: - logging.info('closing environment {}'.format(msg['eid'])) - del self.state[msg['eid']] - if self.env_path is not None: - p = os.path.join( - self.env_path, - "{0}.json".format(msg['eid']) - ) - os.remove(p) - broadcast_envs(self) - elif cmd == 'save_layouts': - if 'data' in msg: - self.app.layouts = msg.get('data') - self.app.save_layouts() - self.broadcast_layouts() - elif cmd == 'forward_to_vis': - packet = msg.get('data') - environment = self.state[packet['eid']] - if packet.get('pane_data') is not False: - packet['pane_data'] = environment['jsons'][packet['target']] - send_to_sources(self, msg.get('data')) - elif cmd == 'layout_item_update': - eid = msg.get('eid') - win = msg.get('win') - self.state[eid]['reload'][win] = msg.get('data') - elif cmd == 'pop_embeddings_pane': - packet = msg.get('data') - eid = packet['eid'] - win = packet['target'] - p = self.state[eid]['jsons'][win] - p['content']['selected'] = None - p['content']['data'] = p['old_content'].pop() - if len(p['old_content']) == 0: - p['content']['has_previous'] = False - p['contentID'] = get_rand_id() - broadcast(self, p, eid) - - def on_close(self): - if self in list(self.subs.values()): - self.subs.pop(self.sid, None) - - -# TODO condense some of the functionality between this class and the -# original SocketHandler class -class ClientSocketWrapper(): - """ - Wraps all of the socket actions in regular request handling, thus - allowing all of the same information to be sent via a polling interface - """ - def __init__(self, app): - self.port = app.port - self.env_path = app.env_path - self.app = app - self.state = app.state - self.subs = app.subs - self.sources = app.sources - self.readonly = app.readonly - self.login_enabled = app.login_enabled - self.messages = [] - self.last_read_time = time.time() - self.open() - try: - if not self.app.socket_wrap_monitor.is_running(): - self.app.socket_wrap_monitor.start() - except AttributeError: - self.app.socket_wrap_monitor = tornado.ioloop.PeriodicCallback( - self.socket_wrap_monitor_thread, 15000 - ) - self.app.socket_wrap_monitor.start() - - def socket_wrap_monitor_thread(self): - # TODO mark wrapped subs and sources separately - if len(self.subs) > 0 or len(self.sources) > 0: - for sub in list(self.subs.values()): - if time.time() - sub.last_read_time > MAX_SOCKET_WAIT: - sub.close() - for sub in list(self.sources.values()): - if time.time() - sub.last_read_time > MAX_SOCKET_WAIT: - sub.close() - else: - self.app.socket_wrap_monitor.stop() - - def broadcast_layouts(self, target_subs=None): - if target_subs is None: - target_subs = self.subs.values() - for sub in target_subs: - sub.write_message(json.dumps( - {'command': 'layout_update', 'data': self.app.layouts} - )) - - def open(self): - if self.login_enabled and not self.current_user: - print("AUTH Failed in SocketHandler") - self.close() - return - self.sid = get_rand_id() - if self not in list(self.subs.values()): - self.eid = 'main' - self.subs[self.sid] = self - logging.info('Mocking new socket: {}'.format(self.sid)) - - self.write_message( - json.dumps({'command': 'register', 'data': self.sid, - 'readonly': self.readonly})) - self.broadcast_layouts([self]) - broadcast_envs(self, [self]) - - def on_message(self, message): - logging.info('from web client: {}'.format(message)) - msg = tornado.escape.json_decode(tornado.escape.to_basestring(message)) - - cmd = msg.get('cmd') - - if self.readonly: - return - - if cmd == 'close': - if 'data' in msg and 'eid' in msg: - logging.info('closing window {}'.format(msg['data'])) - p_data = self.state[msg['eid']]['jsons'].pop(msg['data'], None) - event = { - 'event_type': 'close', - 'target': msg['data'], - 'eid': msg['eid'], - 'pane_data': p_data, - } - send_to_sources(self, event) - elif cmd == 'save': - # save localStorage window metadata - if 'data' in msg and 'eid' in msg: - msg['eid'] = escape_eid(msg['eid']) - self.state[msg['eid']] = \ - copy.deepcopy(self.state[msg['prev_eid']]) - self.state[msg['eid']]['reload'] = msg['data'] - self.eid = msg['eid'] - serialize_env(self.state, [self.eid], env_path=self.env_path) - elif cmd == 'delete_env': - if 'eid' in msg: - logging.info('closing environment {}'.format(msg['eid'])) - del self.state[msg['eid']] - if self.env_path is not None: - p = os.path.join( - self.env_path, - "{0}.json".format(msg['eid']) - ) - os.remove(p) - broadcast_envs(self) - elif cmd == 'save_layouts': - if 'data' in msg: - self.app.layouts = msg.get('data') - self.app.save_layouts() - self.broadcast_layouts() - elif cmd == 'forward_to_vis': - packet = msg.get('data') - environment = self.state[packet['eid']] - packet['pane_data'] = environment['jsons'][packet['target']] - send_to_sources(self, msg.get('data')) - elif cmd == 'layout_item_update': - eid = msg.get('eid') - win = msg.get('win') - self.state[eid]['reload'][win] = msg.get('data') - - def close(self): - if self in list(self.subs.values()): - self.subs.pop(self.sid, None) - - def write_message(self, msg): - self.messages.append(msg) - - def get_messages(self): - to_send = [] - while len(self.messages) > 0: - message = self.messages.pop() - if type(message) is dict: - # Not all messages are being formatted the same way (JSON) - # TODO investigate - message = json.dumps(message) - to_send.append(message) - self.last_read_time = time.time() - return to_send - +# TODO move the logic that actually parses environments and layouts to +# new classes in the data_model folder. +# TODO move generalized initialization logic from these handlers into the +# basehandler +# TODO abstract out any direct references to the app where possible from +# all handlers. Can instead provide accessor functions on the state? class PostHandler(BaseHandler): def initialize(self, app): self.state = app.state @@ -420,13 +53,6 @@ def initialize(self, app): self.port = app.port self.env_path = app.env_path self.login_enabled = app.login_enabled - self.handlers = { - 'update': UpdateHandler, - 'save': SaveHandler, - 'close': CloseHandler, - 'win_exists': ExistsHandler, - 'delete_env': DeleteEnvHandler, - } @check_auth def post(self): @@ -669,103 +295,6 @@ def post(self): self.wrap_func(self, args) -class SocketWrap(BaseHandler): - def initialize(self, app): - self.state = app.state - self.subs = app.subs - self.sources = app.sources - self.port = app.port - self.env_path = app.env_path - self.login_enabled = app.login_enabled - self.app = app - - @check_auth - def post(self): - """Either write a message to the socket, or query what's there""" - # TODO formalize failure reasons - args = tornado.escape.json_decode( - tornado.escape.to_basestring(self.request.body) - ) - type = args.get('message_type') - sid = args.get('sid') - socket_wrap = self.subs.get(sid) - # ensure a wrapper still exists for this connection - if socket_wrap is None: - self.write(json.dumps({'success': False, 'reason': 'closed'})) - return - - # handle the requests - if type == 'query': - messages = socket_wrap.get_messages() - self.write(json.dumps({ - 'success': True, 'messages': messages - })) - elif type == 'send': - msg = args.get('message') - if msg is None: - self.write(json.dumps({'success': False, 'reason': 'no msg'})) - else: - socket_wrap.on_message(msg) - self.write(json.dumps({'success': True})) - else: - self.write(json.dumps({'success': False, 'reason': 'invalid'})) - - @check_auth - def get(self): - """Create a new socket wrapper for this requester, return the id""" - new_sub = ClientSocketWrapper(self.app) - self.write(json.dumps({'success': True, 'sid': new_sub.sid})) - - -# TODO refactor socket wrappers to one class -class VisSocketWrap(BaseHandler): - def initialize(self, app): - self.state = app.state - self.subs = app.subs - self.sources = app.sources - self.port = app.port - self.env_path = app.env_path - self.login_enabled = app.login_enabled - self.app = app - - @check_auth - def post(self): - """Either write a message to the socket, or query what's there""" - # TODO formalize failure reasons - args = tornado.escape.json_decode( - tornado.escape.to_basestring(self.request.body) - ) - type = args.get('message_type') - sid = args.get('sid') - - if sid is None: - new_sub = VisSocketWrapper(self.app) - self.write(json.dumps({'success': True, 'sid': new_sub.sid})) - return - - socket_wrap = self.sources.get(sid) - # ensure a wrapper still exists for this connection - if socket_wrap is None: - self.write(json.dumps({'success': False, 'reason': 'closed'})) - return - - # handle the requests - if type == 'query': - messages = socket_wrap.get_messages() - self.write(json.dumps({ - 'success': True, 'messages': messages - })) - elif type == 'send': - msg = args.get('message') - if msg is None: - self.write(json.dumps({'success': False, 'reason': 'no msg'})) - else: - socket_wrap.on_message(msg) - self.write(json.dumps({'success': True})) - else: - self.write(json.dumps({'success': False, 'reason': 'invalid'})) - - class DeleteEnvHandler(BaseHandler): def initialize(self, app): self.state = app.state diff --git a/py/visdom/server/run_server.py b/py/visdom/server/run_server.py index 0f6e90c3..91296ad4 100644 --- a/py/visdom/server/run_server.py +++ b/py/visdom/server/run_server.py @@ -6,6 +6,10 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +""" +Provides simple entrypoints to set up and run the main visdom server. +""" + from visdom.server.app import Application from visdom.server.defaults import ( DEFAULT_BASE_URL, diff --git a/py/visdom/utils/server_utils.py b/py/visdom/utils/server_utils.py index bd16d6a6..ee33de48 100644 --- a/py/visdom/utils/server_utils.py +++ b/py/visdom/utils/server_utils.py @@ -60,12 +60,13 @@ def check_auth(f): Wrapper for server access methods to ensure that the access is authorized. """ - def _check_auth(app, *args, **kwargs): - app.last_access = time.time() - if app.login_enabled and not app.current_user: - app.set_status(400) + def _check_auth(handler, *args, **kwargs): + # TODO this should call a shared method of the handler + handler.last_access = time.time() + if handler.login_enabled and not handler.current_user: + handler.set_status(400) return - f(app, *args, **kwargs) + f(handler, *args, **kwargs) return _check_auth def set_cookie(value=None):