diff --git a/docs/build/doctrees/convlab.agent.algorithm.doctree b/docs/build/doctrees/convlab.agent.algorithm.doctree
new file mode 100644
index 0000000..a6e2d94
Binary files /dev/null and b/docs/build/doctrees/convlab.agent.algorithm.doctree differ
diff --git a/docs/build/doctrees/convlab.agent.doctree b/docs/build/doctrees/convlab.agent.doctree
new file mode 100644
index 0000000..cf1fbd3
Binary files /dev/null and b/docs/build/doctrees/convlab.agent.doctree differ
diff --git a/docs/build/doctrees/convlab.agent.memory.doctree b/docs/build/doctrees/convlab.agent.memory.doctree
new file mode 100644
index 0000000..9575bbc
Binary files /dev/null and b/docs/build/doctrees/convlab.agent.memory.doctree differ
diff --git a/docs/build/doctrees/convlab.agent.net.doctree b/docs/build/doctrees/convlab.agent.net.doctree
new file mode 100644
index 0000000..d6fc66d
Binary files /dev/null and b/docs/build/doctrees/convlab.agent.net.doctree differ
diff --git a/docs/build/doctrees/convlab.doctree b/docs/build/doctrees/convlab.doctree
new file mode 100644
index 0000000..037ad9a
Binary files /dev/null and b/docs/build/doctrees/convlab.doctree differ
diff --git a/docs/build/doctrees/convlab.env.doctree b/docs/build/doctrees/convlab.env.doctree
new file mode 100644
index 0000000..3cf9f8f
Binary files /dev/null and b/docs/build/doctrees/convlab.env.doctree differ
diff --git a/docs/build/doctrees/convlab.evaluator.doctree b/docs/build/doctrees/convlab.evaluator.doctree
new file mode 100644
index 0000000..93bfe5b
Binary files /dev/null and b/docs/build/doctrees/convlab.evaluator.doctree differ
diff --git a/docs/build/doctrees/convlab.experiment.doctree b/docs/build/doctrees/convlab.experiment.doctree
new file mode 100644
index 0000000..9492571
Binary files /dev/null and b/docs/build/doctrees/convlab.experiment.doctree differ
diff --git a/docs/build/doctrees/convlab.human_eval.doctree b/docs/build/doctrees/convlab.human_eval.doctree
new file mode 100644
index 0000000..6fbb9ed
Binary files /dev/null and b/docs/build/doctrees/convlab.human_eval.doctree differ
diff --git a/docs/build/doctrees/convlab.lib.doctree b/docs/build/doctrees/convlab.lib.doctree
new file mode 100644
index 0000000..6dddcf1
Binary files /dev/null and b/docs/build/doctrees/convlab.lib.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.action_decoder.doctree b/docs/build/doctrees/convlab.modules.action_decoder.doctree
new file mode 100644
index 0000000..0efb480
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.action_decoder.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.action_decoder.multiwoz.doctree b/docs/build/doctrees/convlab.modules.action_decoder.multiwoz.doctree
new file mode 100644
index 0000000..37a6656
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.action_decoder.multiwoz.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.doctree b/docs/build/doctrees/convlab.modules.doctree
new file mode 100644
index 0000000..fc8921f
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.dst.doctree b/docs/build/doctrees/convlab.modules.dst.doctree
new file mode 100644
index 0000000..2d640ce
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.dst.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.dst.multiwoz.doctree b/docs/build/doctrees/convlab.modules.dst.multiwoz.doctree
new file mode 100644
index 0000000..2396ed7
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.dst.multiwoz.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.e2e.doctree b/docs/build/doctrees/convlab.modules.e2e.doctree
new file mode 100644
index 0000000..5e0814b
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.e2e.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.e2e.multiwoz.Mem2Seq.doctree b/docs/build/doctrees/convlab.modules.e2e.multiwoz.Mem2Seq.doctree
new file mode 100644
index 0000000..d53fcb7
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.e2e.multiwoz.Mem2Seq.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.e2e.multiwoz.Mem2Seq.models.doctree b/docs/build/doctrees/convlab.modules.e2e.multiwoz.Mem2Seq.models.doctree
new file mode 100644
index 0000000..36f343d
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.e2e.multiwoz.Mem2Seq.models.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.e2e.multiwoz.Mem2Seq.utils.doctree b/docs/build/doctrees/convlab.modules.e2e.multiwoz.Mem2Seq.utils.doctree
new file mode 100644
index 0000000..5976906
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.e2e.multiwoz.Mem2Seq.utils.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.e2e.multiwoz.Sequicity.doctree b/docs/build/doctrees/convlab.modules.e2e.multiwoz.Sequicity.doctree
new file mode 100644
index 0000000..1b966cc
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.e2e.multiwoz.Sequicity.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.e2e.multiwoz.doctree b/docs/build/doctrees/convlab.modules.e2e.multiwoz.doctree
new file mode 100644
index 0000000..77d332a
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.e2e.multiwoz.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.nlg.doctree b/docs/build/doctrees/convlab.modules.nlg.doctree
new file mode 100644
index 0000000..11069c5
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.nlg.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.nlg.multiwoz.doctree b/docs/build/doctrees/convlab.modules.nlg.multiwoz.doctree
new file mode 100644
index 0000000..559e204
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.nlg.multiwoz.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.nlg.multiwoz.multiwoz_template_nlg.doctree b/docs/build/doctrees/convlab.modules.nlg.multiwoz.multiwoz_template_nlg.doctree
new file mode 100644
index 0000000..b7aca3d
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.nlg.multiwoz.multiwoz_template_nlg.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.nlg.multiwoz.sc_lstm.doctree b/docs/build/doctrees/convlab.modules.nlg.multiwoz.sc_lstm.doctree
new file mode 100644
index 0000000..4006fbf
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.nlg.multiwoz.sc_lstm.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.nlu.doctree b/docs/build/doctrees/convlab.modules.nlu.doctree
new file mode 100644
index 0000000..58a4130
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.nlu.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.nlu.multiwoz.doctree b/docs/build/doctrees/convlab.modules.nlu.multiwoz.doctree
new file mode 100644
index 0000000..898c70e
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.nlu.multiwoz.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.nlu.multiwoz.milu.doctree b/docs/build/doctrees/convlab.modules.nlu.multiwoz.milu.doctree
new file mode 100644
index 0000000..31904ba
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.nlu.multiwoz.milu.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.nlu.multiwoz.onenet.doctree b/docs/build/doctrees/convlab.modules.nlu.multiwoz.onenet.doctree
new file mode 100644
index 0000000..aa711f6
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.nlu.multiwoz.onenet.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.nlu.multiwoz.svm.doctree b/docs/build/doctrees/convlab.modules.nlu.multiwoz.svm.doctree
new file mode 100644
index 0000000..68d1fb2
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.nlu.multiwoz.svm.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.policy.doctree b/docs/build/doctrees/convlab.modules.policy.doctree
new file mode 100644
index 0000000..c6250f5
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.policy.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.policy.system.doctree b/docs/build/doctrees/convlab.modules.policy.system.doctree
new file mode 100644
index 0000000..88a3f4a
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.policy.system.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.policy.system.multiwoz.doctree b/docs/build/doctrees/convlab.modules.policy.system.multiwoz.doctree
new file mode 100644
index 0000000..1cef7ed
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.policy.system.multiwoz.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.policy.system.multiwoz.vanilla_mle.doctree b/docs/build/doctrees/convlab.modules.policy.system.multiwoz.vanilla_mle.doctree
new file mode 100644
index 0000000..fd060e2
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.policy.system.multiwoz.vanilla_mle.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.policy.user.doctree b/docs/build/doctrees/convlab.modules.policy.user.doctree
new file mode 100644
index 0000000..5c1733a
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.policy.user.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.policy.user.multiwoz.doctree b/docs/build/doctrees/convlab.modules.policy.user.multiwoz.doctree
new file mode 100644
index 0000000..92099fe
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.policy.user.multiwoz.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.state_encoder.doctree b/docs/build/doctrees/convlab.modules.state_encoder.doctree
new file mode 100644
index 0000000..4cbda1f
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.state_encoder.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.state_encoder.multiwoz.doctree b/docs/build/doctrees/convlab.modules.state_encoder.multiwoz.doctree
new file mode 100644
index 0000000..a82a53f
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.state_encoder.multiwoz.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.usr.doctree b/docs/build/doctrees/convlab.modules.usr.doctree
new file mode 100644
index 0000000..12e53fe
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.usr.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.usr.multiwoz.doctree b/docs/build/doctrees/convlab.modules.usr.multiwoz.doctree
new file mode 100644
index 0000000..9db8798
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.usr.multiwoz.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.util.doctree b/docs/build/doctrees/convlab.modules.util.doctree
new file mode 100644
index 0000000..1358f12
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.util.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.word_dst.doctree b/docs/build/doctrees/convlab.modules.word_dst.doctree
new file mode 100644
index 0000000..235fa78
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.word_dst.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.word_dst.multiwoz.doctree b/docs/build/doctrees/convlab.modules.word_dst.multiwoz.doctree
new file mode 100644
index 0000000..d727d87
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.word_dst.multiwoz.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.word_policy.doctree b/docs/build/doctrees/convlab.modules.word_policy.doctree
new file mode 100644
index 0000000..ff38b1e
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.word_policy.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.word_policy.multiwoz.doctree b/docs/build/doctrees/convlab.modules.word_policy.multiwoz.doctree
new file mode 100644
index 0000000..11cf559
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.word_policy.multiwoz.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.word_policy.multiwoz.mdrg.doctree b/docs/build/doctrees/convlab.modules.word_policy.multiwoz.mdrg.doctree
new file mode 100644
index 0000000..61d4ccd
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.word_policy.multiwoz.mdrg.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.word_policy.multiwoz.mdrg.model.doctree b/docs/build/doctrees/convlab.modules.word_policy.multiwoz.mdrg.model.doctree
new file mode 100644
index 0000000..94974bb
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.word_policy.multiwoz.mdrg.model.doctree differ
diff --git a/docs/build/doctrees/convlab.modules.word_policy.multiwoz.mdrg.utils.doctree b/docs/build/doctrees/convlab.modules.word_policy.multiwoz.mdrg.utils.doctree
new file mode 100644
index 0000000..166c2bd
Binary files /dev/null and b/docs/build/doctrees/convlab.modules.word_policy.multiwoz.mdrg.utils.doctree differ
diff --git a/docs/build/doctrees/convlab.spec.doctree b/docs/build/doctrees/convlab.spec.doctree
new file mode 100644
index 0000000..1331e14
Binary files /dev/null and b/docs/build/doctrees/convlab.spec.doctree differ
diff --git a/docs/build/doctrees/environment.pickle b/docs/build/doctrees/environment.pickle
index 9372712..90874e2 100644
Binary files a/docs/build/doctrees/environment.pickle and b/docs/build/doctrees/environment.pickle differ
diff --git a/docs/build/doctrees/index.doctree b/docs/build/doctrees/index.doctree
index bd734de..ec56b35 100644
Binary files a/docs/build/doctrees/index.doctree and b/docs/build/doctrees/index.doctree differ
diff --git a/docs/build/doctrees/modules.doctree b/docs/build/doctrees/modules.doctree
index 0498735..ec1b5be 100644
Binary files a/docs/build/doctrees/modules.doctree and b/docs/build/doctrees/modules.doctree differ
diff --git a/docs/build/doctrees/nlu.doctree b/docs/build/doctrees/nlu.doctree
deleted file mode 100644
index ba87fca..0000000
Binary files a/docs/build/doctrees/nlu.doctree and /dev/null differ
diff --git a/docs/build/doctrees/tasktk.dialog_agent.doctree b/docs/build/doctrees/tasktk.dialog_agent.doctree
deleted file mode 100644
index 4fa3c05..0000000
Binary files a/docs/build/doctrees/tasktk.dialog_agent.doctree and /dev/null differ
diff --git a/docs/build/doctrees/tasktk.doctree b/docs/build/doctrees/tasktk.doctree
deleted file mode 100644
index d4af737..0000000
Binary files a/docs/build/doctrees/tasktk.doctree and /dev/null differ
diff --git a/docs/build/doctrees/tasktk.dst.doctree b/docs/build/doctrees/tasktk.dst.doctree
deleted file mode 100644
index 04f164f..0000000
Binary files a/docs/build/doctrees/tasktk.dst.doctree and /dev/null differ
diff --git a/docs/build/doctrees/tasktk.nlg.doctree b/docs/build/doctrees/tasktk.nlg.doctree
deleted file mode 100644
index beba308..0000000
Binary files a/docs/build/doctrees/tasktk.nlg.doctree and /dev/null differ
diff --git a/docs/build/doctrees/tasktk.nlu.doctree b/docs/build/doctrees/tasktk.nlu.doctree
deleted file mode 100644
index db75664..0000000
Binary files a/docs/build/doctrees/tasktk.nlu.doctree and /dev/null differ
diff --git a/docs/build/doctrees/tasktk.policy.doctree b/docs/build/doctrees/tasktk.policy.doctree
deleted file mode 100644
index c962b96..0000000
Binary files a/docs/build/doctrees/tasktk.policy.doctree and /dev/null differ
diff --git a/docs/build/doctrees/tasktk.usr.doctree b/docs/build/doctrees/tasktk.usr.doctree
deleted file mode 100644
index c1450d3..0000000
Binary files a/docs/build/doctrees/tasktk.usr.doctree and /dev/null differ
diff --git a/docs/build/doctrees/tasktk.util.doctree b/docs/build/doctrees/tasktk.util.doctree
deleted file mode 100644
index db12148..0000000
Binary files a/docs/build/doctrees/tasktk.util.doctree and /dev/null differ
diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo
index 9cdd219..f96be4f 100644
--- a/docs/build/html/.buildinfo
+++ b/docs/build/html/.buildinfo
@@ -1,4 +1,4 @@
 # Sphinx build info version 1
 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: 9f94c3cb63bd71b6e2e5402fa59e1636
+config: 8d4781a376070f84a925d5268ce15f35
 tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/docs/build/html/.nojekyll b/docs/build/html/.nojekyll
new file mode 100644
index 0000000..e69de29
diff --git a/docs/build/html/_modules/convlab/agent.html b/docs/build/html/_modules/convlab/agent.html
new file mode 100644
index 0000000..4c2e0d1
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent.html
@@ -0,0 +1,581 @@
+<!-- Sphinx HTML page scaffolding (head, nav, sidebar) omitted; page title: "convlab.agent — ConvLab 0.1 documentation" -->
Source code for convlab.agent
+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+from copy import deepcopy
+
+# The agent module
+import numpy as np
+import pandas as pd
+import pydash as ps
+import torch
+
+from convlab.agent import algorithm, memory
+from convlab.agent.algorithm import policy_util
+from convlab.agent.net import net_util
+from convlab.lib import logger, util
+from convlab.lib.decorator import lab_api
+from convlab.modules import nlu, dst, word_dst, nlg, state_encoder, action_decoder
+
+logger = logger.get_logger(__name__)
+
+
+
+class Agent:
+    '''
+    Agent abstraction; implements the API to interface with Env in SLM Lab
+    Contains algorithm, memory, body
+    '''
+
+    def __init__(self, spec, body, a=None, global_nets=None):
+        self.spec = spec
+        self.a = a or 0  # for multi-agent
+        self.agent_spec = spec['agent'][self.a]
+        self.name = self.agent_spec['name']
+        assert not ps.is_list(global_nets), f'single agent global_nets must be a dict, got {global_nets}'
+        # set components
+        self.body = body
+        body.agent = self
+        MemoryClass = getattr(memory, ps.get(self.agent_spec, 'memory.name'))
+        self.body.memory = MemoryClass(self.agent_spec['memory'], self.body)
+        AlgorithmClass = getattr(algorithm, ps.get(self.agent_spec, 'algorithm.name'))
+        self.algorithm = AlgorithmClass(self, global_nets)
+
+        logger.info(util.self_desc(self))
+
+    @lab_api
+    def act(self, state):
+        '''Standard act method from algorithm.'''
+        with torch.no_grad():  # for efficiency, only calc grad in algorithm.train
+            action = self.algorithm.act(state)
+        return action
+
+    @lab_api
+    def update(self, state, action, reward, next_state, done):
+        '''Update per timestep after env transitions, e.g. memory, algorithm, update agent params, train net'''
+        self.body.update(state, action, reward, next_state, done)
+        if util.in_eval_lab_modes():  # eval does not update agent for training
+            return
+        self.body.memory.update(state, action, reward, next_state, done)
+        loss = self.algorithm.train()
+        if not np.isnan(loss):  # set for log_summary()
+            self.body.loss = loss
+        explore_var = self.algorithm.update()
+        return loss, explore_var
+
+    @lab_api
+    def save(self, ckpt=None):
+        '''Save agent'''
+        if util.in_eval_lab_modes():  # eval does not save new models
+            return
+        self.algorithm.save(ckpt=ckpt)
+
+    @lab_api
+    def close(self):
+        '''Close and cleanup agent at the end of a session, e.g. save model'''
+        self.save()
+
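Editor's note: a minimal sketch of the control loop that Agent.act() and Agent.update() above are designed for. It assumes any Gym-style object with reset()/step(); the names here are illustrative, not a ConvLab entry point.

def run_episode(agent, env):
    '''Drive one episode through the Agent API (sketch; `env` is any
    Gym-style object exposing reset() and step()).'''
    state = env.reset()
    done = False
    while not done:
        action = agent.act(state)  # no-grad forward pass via algorithm.act
        next_state, reward, done, info = env.step(action)
        # stores the transition in memory and trains when the algorithm signals
        agent.update(state, action, reward, next_state, done)
        state = next_state
    agent.close()  # saves the model unless running in an eval lab mode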
+class DialogAgent(Agent):
+    '''
+    Class for all Agents.
+    Standardizes the Agent design to work in Lab.
+    Access Envs properties by: Agents - AgentSpace - AEBSpace - EnvSpace - Envs
+    '''
+
+    def __init__(self, spec, body, a=None, global_nets=None):
+        self.spec = spec
+        self.a = a or 0  # for compatibility with agent_space
+        self.agent_spec = spec['agent'][self.a]
+        self.name = self.agent_spec['name']
+        assert not ps.is_list(global_nets), f'single agent global_nets must be a dict, got {global_nets}'
+        self.nlu = None
+        if 'nlu' in self.agent_spec:
+            params = deepcopy(ps.get(self.agent_spec, 'nlu'))
+            NluClass = getattr(nlu, params.pop('name'))
+            self.nlu = NluClass(**params)
+        self.dst = None
+        if 'dst' in self.agent_spec:
+            params = deepcopy(ps.get(self.agent_spec, 'dst'))
+            DstClass = getattr(dst, params.pop('name'))
+            self.dst = DstClass(**params)
+        if 'word_dst' in self.agent_spec:
+            params = deepcopy(ps.get(self.agent_spec, 'word_dst'))
+            DstClass = getattr(word_dst, params.pop('name'))
+            self.dst = DstClass(**params)
+        self.state_encoder = None
+        if 'state_encoder' in self.agent_spec:
+            params = deepcopy(ps.get(self.agent_spec, 'state_encoder'))
+            StateEncoderClass = getattr(state_encoder, params.pop('name'))
+            self.state_encoder = StateEncoderClass(**params)
+        self.action_decoder = None
+        if 'action_decoder' in self.agent_spec:
+            params = deepcopy(ps.get(self.agent_spec, 'action_decoder'))
+            ActionDecoderClass = getattr(action_decoder, params.pop('name'))
+            self.action_decoder = ActionDecoderClass(**params)
+        self.nlg = None
+        if 'nlg' in self.agent_spec:
+            params = deepcopy(ps.get(self.agent_spec, 'nlg'))
+            NlgClass = getattr(nlg, params.pop('name'))
+            self.nlg = NlgClass(**params)
+        self.body = body
+        body.agent = self
+        AlgorithmClass = getattr(algorithm, ps.get(self.agent_spec, 'algorithm.name'))
+        self.algorithm = AlgorithmClass(self, global_nets)
+        if ps.get(self.agent_spec, 'memory'):
+            MemoryClass = getattr(memory, ps.get(self.agent_spec, 'memory.name'))
+            self.body.memory = MemoryClass(self.agent_spec['memory'], self.body)
+        self.warmup_epi = ps.get(self.agent_spec, 'algorithm.warmup_epi') or -1
+        self.body.state, self.body.encoded_state, self.body.action = None, None, None
+        logger.info(util.self_desc(self))
+
+    @lab_api
+    def reset(self, obs):
+        '''Do agent reset per session, such as memory pointer'''
+        logger.debug(f'Agent {self.a} reset')
+        if self.dst:
+            self.dst.init_session()
+        if hasattr(self.algorithm, "reset"):  # mainly for external policies that may need to reset their state
+            self.algorithm.reset()
+
+        input_act, state, encoded_state = self.state_update(obs, "null")  # "null" action to be compatible with MDBT
+
+        self.body.state, self.body.encoded_state = state, encoded_state
+
+    @lab_api
+    def act(self, obs):
+        '''Standard act method from algorithm.'''
+        action = self.algorithm.act(self.body.encoded_state)
+        self.body.action = action
+
+        output_act, decoded_action = self.action_decode(action, self.body.state)
+
+        logger.act(f'System action: {action}')
+        logger.nl(f'System utterance: {decoded_action}')
+
+        return decoded_action
+
+    def state_update(self, obs, action):
+        # update history
+        if self.dst:
+            self.dst.state['history'].append([str(action)])
+
+        # NLU parsing
+        input_act = self.nlu.parse(obs, sum(self.dst.state['history'], []) if self.dst else []) if self.nlu else obs
+
+        # state tracking
+        state = self.dst.update(input_act) if self.dst else input_act
+
+        # update history
+        if self.dst:
+            self.dst.state['history'][-1].append(str(obs))
+
+        # encode state
+        encoded_state = self.state_encoder.encode(state) if self.state_encoder else state
+
+        if self.nlu and self.dst:
+            self.dst.state['user_action'] = input_act
+        elif self.dst and not isinstance(self.dst, word_dst.MDBTTracker):  # for act-in act-out agent
+            self.dst.state['user_action'] = obs
+
+        logger.nl(f'User utterance: {obs}')
+        logger.act(f'Inferred user action: {input_act}')
+        logger.state(f'Dialog state: {state}')
+
+        return input_act, state, encoded_state
+
+    def action_decode(self, action, state):
+        output_act = self.action_decoder.decode(action, state) if self.action_decoder else action
+        decoded_action = self.nlg.generate(output_act) if self.nlg else output_act
+        return output_act, decoded_action
+
+    def get_env(self):
+        return self.body.eval_env if util.in_eval_lab_modes() else self.body.env
+
+    @lab_api
+    def update(self, obs, action, reward, next_obs, done):
+        '''Update per timestep after env transitions, e.g. memory, algorithm, update agent params, train net'''
+        # update state
+        input_act, next_state, encoded_state = self.state_update(next_obs, action)
+
+        # update body
+        self.body.update(self.body.state, action, reward, next_state, done)
+
+        # update memory
+        if util.in_eval_lab_modes() or self.algorithm.__class__.__name__ == 'ExternalPolicy':  # eval does not update agent for training
+            self.body.state, self.body.encoded_state = next_state, encoded_state
+            return
+
+        if not hasattr(self.body, 'warmup_memory') or self.body.env.clock.epi > self.warmup_epi:
+            self.body.memory.update(self.body.encoded_state, self.body.action, reward, encoded_state, done)
+        else:
+            self.body.warmup_memory.update(self.body.encoded_state, self.body.action, reward, encoded_state, done)
+
+        # update body
+        self.body.state, self.body.encoded_state = next_state, encoded_state
+
+        # train algorithm
+        loss = self.algorithm.train()
+        if not np.isnan(loss):  # set for log_summary()
+            self.body.loss = loss
+        explore_var = self.algorithm.update()
+
+        return loss, explore_var
+
+    @lab_api
+    def save(self, ckpt=None):
+        '''Save agent'''
+        if self.algorithm.__class__.__name__ == 'ExternalPolicy':
+            return
+        if util.in_eval_lab_modes():
+            # eval does not save new models
+            return
+        self.algorithm.save(ckpt=ckpt)
+
+    @lab_api
+    def close(self):
+        '''Close and cleanup agent at the end of a session, e.g. save model'''
+        self.save()
+
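Editor's note: a hedged sketch of the spec fragment DialogAgent.__init__ consumes. Each component entry's 'name' is popped and resolved via getattr on the matching convlab.modules package; the component names below are illustrative placeholders and must match classes actually exported by those packages.

agent_spec_example = {
    'agent': [{
        'name': 'DialogAgent',
        'nlu': {'name': 'OneNetLU'},        # resolved as getattr(nlu, 'OneNetLU')(**params)
        'dst': {'name': 'RuleDST'},         # resolved as getattr(dst, 'RuleDST')(**params)
        'nlg': {'name': 'TemplateNLG'},     # resolved as getattr(nlg, 'TemplateNLG')(**params)
        'algorithm': {'name': 'DQN', 'warmup_epi': 100},
        'memory': {'name': 'Replay', 'batch_size': 16, 'max_size': 10000},
    }],
}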
+class Body:
+    '''
+    Body of an agent inside an environment, it:
+    - enables the automatic dimension inference for constructing network input/output
+    - acts as reference bridge between agent and environment (useful for multi-agent, multi-env)
+    - acts as non-gradient variable storage for monitoring and analysis
+    '''
+
+    def __init__(self, env, agent_spec, aeb=(0, 0, 0)):
+        # essential reference variables
+        self.agent = None  # set later
+        self.env = env
+        self.aeb = aeb
+        self.a, self.e, self.b = aeb
+
+        # variables set during init_algorithm_params
+        self.explore_var = np.nan  # action exploration: epsilon or tau
+        self.entropy_coef = np.nan  # entropy for exploration
+
+        # debugging/logging variables, set in train or loss function
+        self.loss = np.nan
+        self.mean_entropy = np.nan
+        self.mean_grad_norm = np.nan
+
+        self.ckpt_total_reward = np.nan
+        self.total_reward = 0  # init to 0, but don't ckpt before the end of an epi
+        self.total_reward_ma = np.nan
+        self.ma_window = 100
+        # store current and best reward_ma for model checkpointing and early termination if all the environments are solved
+        self.best_reward_ma = -np.inf
+        self.eval_reward_ma = np.nan
+
+        # dataframes to track data for analysis.analyze_session
+        # track training data per episode
+        self.train_df = pd.DataFrame(columns=[
+            'epi', 't', 'wall_t', 'opt_step', 'frame', 'fps', 'total_reward', 'avg_return', 'avg_len', 'avg_success', 'loss', 'lr',
+            'explore_var', 'entropy_coef', 'entropy', 'grad_norm'])
+        # track eval data within run_eval. the same as train_df except for reward
+        self.eval_df = self.train_df.copy()
+
+        # the specific agent-env interface variables for a body
+        self.observation_space = self.env.observation_space
+        self.action_space = self.env.action_space
+        self.observable_dim = self.env.observable_dim
+        self.state_dim = self.observable_dim['state']
+        self.action_dim = self.env.action_dim
+        self.is_discrete = self.env.is_discrete
+        # set the ActionPD class for sampling action
+        self.action_type = policy_util.get_action_type(self.action_space)
+        self.action_pdtype = agent_spec[self.a]['algorithm'].get('action_pdtype')
+        if self.action_pdtype in (None, 'default'):
+            self.action_pdtype = policy_util.ACTION_PDS[self.action_type][0]
+        self.ActionPD = policy_util.get_action_pd_cls(self.action_pdtype, self.action_type)
+
+    def update(self, state, action, reward, next_state, done):
+        '''Interface update method for body at agent.update()'''
+        if hasattr(self.env.u_env, 'raw_reward'):  # use raw_reward if reward is preprocessed
+            reward = self.env.u_env.raw_reward
+        if self.ckpt_total_reward is np.nan:  # init
+            self.ckpt_total_reward = reward
+        else:  # reset on epi_start, else keep adding. generalized for vec env
+            self.ckpt_total_reward = self.ckpt_total_reward * (1 - self.epi_start) + reward
+        self.total_reward = done * self.ckpt_total_reward + (1 - done) * self.total_reward
+        self.epi_start = done
+
+    def __str__(self):
+        return f'body: {util.to_json(util.get_class_attr(self))}'
+
+    def calc_df_row(self, env):
+        '''Calculate a row for updating train_df or eval_df.'''
+        frame = self.env.clock.get('frame')
+        wall_t = env.clock.get_elapsed_wall_t()
+        fps = 0 if wall_t == 0 else frame / wall_t
+
+        # update debugging variables
+        if net_util.to_check_train_step():
+            grad_norms = net_util.get_grad_norms(self.agent.algorithm)
+            self.mean_grad_norm = np.nan if ps.is_empty(grad_norms) else np.mean(grad_norms)
+
+        row = pd.Series({
+            # epi and frame are always measured from training env
+            'epi': self.env.clock.get('epi'),
+            # t and reward are measured from a given env or eval_env
+            't': env.clock.get('t'),
+            'wall_t': wall_t,
+            'opt_step': self.env.clock.get('opt_step'),
+            'frame': frame,
+            'fps': fps,
+            'total_reward': np.nanmean(self.total_reward),  # guard for vec env
+            'avg_return': np.nan,  # update outside
+            'avg_len': np.nan,  # update outside
+            'avg_success': np.nan,  # update outside
+            'loss': self.loss,
+            'lr': self.get_mean_lr(),
+            'explore_var': self.explore_var,
+            'entropy_coef': self.entropy_coef if hasattr(self, 'entropy_coef') else np.nan,
+            'entropy': self.mean_entropy,
+            'grad_norm': self.mean_grad_norm,
+        }, dtype=np.float32)
+        assert all(col in self.train_df.columns for col in row.index), f'Mismatched row keys: {row.index} vs df columns {self.train_df.columns}'
+        return row
+
+    def train_ckpt(self):
+        '''Checkpoint to update body.train_df data'''
+        row = self.calc_df_row(self.env)
+        # append efficiently to df
+        self.train_df.loc[len(self.train_df)] = row
+        # update current reward_ma
+        self.total_reward_ma = self.train_df[-self.ma_window:]['total_reward'].mean()
+        self.train_df.iloc[-1]['avg_return'] = self.total_reward_ma
+
+    def eval_ckpt(self, eval_env, avg_return, avg_len, avg_success):
+        '''Checkpoint to update body.eval_df data'''
+        row = self.calc_df_row(eval_env)
+        # append efficiently to df
+        self.eval_df.loc[len(self.eval_df)] = row
+        # update current reward_ma
+        self.eval_reward_ma = avg_return
+        self.eval_df.iloc[-1]['avg_return'] = avg_return
+        self.eval_df.iloc[-1]['avg_len'] = avg_len
+        self.eval_df.iloc[-1]['avg_success'] = avg_success
+
+    def get_mean_lr(self):
+        '''Gets the average current learning rate of the algorithm's nets.'''
+        if not hasattr(self.agent.algorithm, 'net_names'):
+            return np.nan
+        lrs = []
+        for attr, obj in self.agent.algorithm.__dict__.items():
+            if attr.endswith('lr_scheduler'):
+                lrs.append(obj.get_lr())
+        return np.mean(lrs)
+
+    def get_log_prefix(self):
+        '''Get the prefix for logging'''
+        spec = self.agent.spec
+        spec_name = spec['name']
+        trial_index = spec['meta']['trial']
+        session_index = spec['meta']['session']
+        prefix = f'Trial {trial_index} session {session_index} {spec_name}_t{trial_index}_s{session_index}'
+        return prefix
+
+    def log_metrics(self, metrics, df_mode):
+        '''Log session metrics'''
+        prefix = self.get_log_prefix()
+        row_str = ' '.join([f'{k}: {v:g}' for k, v in metrics.items()])
+        msg = f'{prefix} [{df_mode}_df metrics] {row_str}'
+        logger.info(msg)
+
+    def log_summary(self, df_mode):
+        '''
+        Log the summary for this body when its environment is done
+        @param str:df_mode 'train' or 'eval'
+        '''
+        prefix = self.get_log_prefix()
+        df = getattr(self, f'{df_mode}_df')
+        last_row = df.iloc[-1]
+        row_str = ' '.join([f'{k}: {v:g}' for k, v in last_row.items()])
+        msg = f'{prefix} [{df_mode}_df] {row_str}'
+        logger.info(msg)
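Editor's note: a standalone sketch of the bookkeeping in Body.update() above. `done` acts as a 0/1 mask, so the running episode return resets itself at episode boundaries without an explicit branch; the reward values below are made up for illustration.

import numpy as np

ckpt_total_reward, total_reward, epi_start = np.nan, 0.0, 0.0
for reward, done in [(1.0, 0.0), (2.0, 0.0), (3.0, 1.0), (5.0, 0.0)]:
    if ckpt_total_reward is np.nan:  # first-ever step
        ckpt_total_reward = reward
    else:  # zeroed right after an episode ended, else accumulated
        ckpt_total_reward = ckpt_total_reward * (1 - epi_start) + reward
    # total_reward only latches the accumulated return when done == 1
    total_reward = done * ckpt_total_reward + (1 - done) * total_reward
    epi_start = done
print(total_reward)  # 6.0: the return of the completed episode (1 + 2 + 3)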
diff --git a/docs/build/html/_modules/convlab/agent/algorithm/actor_critic.html b/docs/build/html/_modules/convlab/agent/algorithm/actor_critic.html
new file mode 100644
index 0000000..85b25a7
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/algorithm/actor_critic.html
@@ -0,0 +1,499 @@
+<!-- Sphinx HTML page scaffolding (head, nav, sidebar) omitted; page title: "convlab.agent.algorithm.actor_critic — ConvLab 0.1 documentation" -->
Source code for convlab.agent.algorithm.actor_critic
+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import numpy as np
+import pydash as ps
+import torch
+
+from convlab.agent import net
+from convlab.agent.algorithm import policy_util
+from convlab.agent.algorithm.reinforce import Reinforce
+from convlab.agent.net import net_util
+from convlab.lib import logger, math_util, util
+from convlab.lib.decorator import lab_api
+
+logger = logger.get_logger(__name__)
+
+
+
+class ActorCritic(Reinforce):
+    '''
+    Implementation of single threaded Advantage Actor Critic
+    Original paper: "Asynchronous Methods for Deep Reinforcement Learning"
+    https://arxiv.org/abs/1602.01783
+    Algorithm specific spec param:
+    memory.name: batch (through OnPolicyBatchReplay memory class) or episodic (through OnPolicyReplay memory class)
+    lam: if not null, used as the lambda value of generalized advantage estimation (GAE) introduced in "High-Dimensional Continuous Control Using Generalized Advantage Estimation", https://arxiv.org/abs/1506.02438. This lambda controls the bias-variance tradeoff for GAE. Floating point value between 0 and 1. Lower values correspond to more bias, less variance; higher values to more variance, less bias. The algorithm becomes A2C(GAE).
+    num_step_returns: if lam is null and this is not null, specifies the number of steps for N-step returns from "Asynchronous Methods for Deep Reinforcement Learning". The algorithm becomes A2C(Nstep).
+    If both lam and num_step_returns are null, use the default TD error. Then the algorithm stays as AC.
+    net.type: whether the actor and critic should share params (e.g. through 'MLPNetShared') or have separate params (e.g. through 'MLPNetSeparate'). If param sharing is used then there is also the option to control the weight given to the policy and value components of the loss function through 'policy_loss_coef' and 'val_loss_coef'
+    Algorithm - separate actor and critic:
+        Repeat:
+            1. Collect k examples
+            2. Train the critic network using these examples
+            3. Calculate the advantage of each example using the critic
+            4. Multiply the advantage by the negative of log probability of the action taken, and sum all the values. This is the policy loss.
+            5. Calculate the gradient of the parameters of the actor network with respect to the policy loss
+            6. Update the actor network parameters using the gradient
+    Algorithm - shared parameters:
+        Repeat:
+            1. Collect k examples
+            2. Calculate the target for each example for the critic
+            3. Compute the current estimate of state-value for each example using the critic
+            4. Calculate the critic loss using a regression loss (e.g. square loss) between the target and estimate of the state-value for each example
+            5. Calculate the advantage of each example using the rewards and critic
+            6. Multiply the advantage by the negative of log probability of the action taken, and sum all the values. This is the policy loss.
+            7. Compute the total loss by summing the value and policy losses
+            8. Calculate the gradient of the parameters of the shared network with respect to the total loss
+            9. Update the shared network parameters using the gradient
+
+    e.g. algorithm_spec
+    "algorithm": {
+        "name": "ActorCritic",
+        "action_pdtype": "default",
+        "action_policy": "default",
+        "explore_var_spec": null,
+        "gamma": 0.99,
+        "lam": 1.0,
+        "num_step_returns": 100,
+        "entropy_coef_spec": {
+            "name": "linear_decay",
+            "start_val": 0.01,
+            "end_val": 0.001,
+            "start_step": 100,
+            "end_step": 5000,
+        },
+        "policy_loss_coef": 1.0,
+        "val_loss_coef": 0.01,
+        "training_frequency": 1,
+    }
+
+    e.g. special net_spec param "shared" to share/separate Actor/Critic
+    "net": {
+        "type": "MLPNet",
+        "shared": true,
+        ...
+    '''
+
+    @lab_api
+    def init_algorithm_params(self):
+        '''Initialize other algorithm parameters'''
+        # set default
+        util.set_attr(self, dict(
+            action_pdtype='default',
+            action_policy='default',
+            explore_var_spec=None,
+            entropy_coef_spec=None,
+            policy_loss_coef=1.0,
+            val_loss_coef=1.0,
+        ))
+        util.set_attr(self, self.algorithm_spec, [
+            'action_pdtype',
+            'action_policy',
+            # theoretically, AC does not have policy update; but in this implementation we have such option
+            'explore_var_spec',
+            'gamma',  # the discount factor
+            'lam',
+            'num_step_returns',
+            'entropy_coef_spec',
+            'policy_loss_coef',
+            'val_loss_coef',
+            'training_frequency',
+        ])
+        self.to_train = 0
+        self.action_policy = getattr(policy_util, self.action_policy)
+        self.explore_var_scheduler = policy_util.VarScheduler(self.explore_var_spec)
+        self.body.explore_var = self.explore_var_scheduler.start_val
+        if self.entropy_coef_spec is not None:
+            self.entropy_coef_scheduler = policy_util.VarScheduler(self.entropy_coef_spec)
+            self.body.entropy_coef = self.entropy_coef_scheduler.start_val
+        # Select appropriate methods to calculate advs and v_targets for training
+        if self.lam is not None:
+            self.calc_advs_v_targets = self.calc_gae_advs_v_targets
+        elif self.num_step_returns is not None:
+            self.calc_advs_v_targets = self.calc_nstep_advs_v_targets
+        else:
+            self.calc_advs_v_targets = self.calc_ret_advs_v_targets
+
+    @lab_api
+    def init_nets(self, global_nets=None):
+        '''
+        Initialize the neural networks used to learn the actor and critic from the spec
+        Below we automatically select an appropriate net based on two different conditions
+        1. If the action space is discrete or continuous action
+            - Networks for continuous action spaces have two heads and return two values, the first is a tensor containing the mean of the action policy, the second is a tensor containing the std deviation of the action policy. The distribution is assumed to be a Gaussian (Normal) distribution.
+            - Networks for discrete action spaces have a single head and return the logits for a categorical probability distribution over the discrete actions
+        2. If the actor and critic are separate or share weights
+            - If the networks share weights then the single network returns a list.
+            - Continuous action spaces: The returned list contains 3 elements: the first element contains the mean output for the actor (policy), the second element the std dev of the policy, and the third element is the state-value estimated by the network.
+            - Discrete action spaces: The returned list contains 2 elements. The first element is a tensor containing the logits for a categorical probability distribution over the actions. The second element contains the state-value estimated by the network.
+        3. If the network type is feedforward, convolutional, or recurrent
+            - Feedforward and convolutional networks take a single state as input and require an OnPolicyReplay or OnPolicyBatchReplay memory
+            - Recurrent networks take n states as input and require env spec "frame_op": "concat", "frame_op_len": seq_len
+        '''
+        assert 'shared' in self.net_spec, 'Specify "shared" for ActorCritic network in net_spec'
+        self.shared = self.net_spec['shared']
+
+        # create actor/critic specific specs
+        actor_net_spec = self.net_spec.copy()
+        critic_net_spec = self.net_spec.copy()
+        for k in self.net_spec:
+            if 'actor_' in k:
+                actor_net_spec[k.replace('actor_', '')] = actor_net_spec.pop(k)
+                critic_net_spec.pop(k)
+            if 'critic_' in k:
+                critic_net_spec[k.replace('critic_', '')] = critic_net_spec.pop(k)
+                actor_net_spec.pop(k)
+        if critic_net_spec['use_same_optim']:
+            critic_net_spec = actor_net_spec
+
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body, add_critic=self.shared)
+        # main actor network; may contain the critic out_dim if self.shared == True
+        NetClass = getattr(net, actor_net_spec['type'])
+        self.net = NetClass(actor_net_spec, in_dim, out_dim)
+        self.net_names = ['net']
+        if not self.shared:  # add separate network for critic
+            critic_out_dim = 1
+            CriticNetClass = getattr(net, critic_net_spec['type'])
+            self.critic_net = CriticNetClass(critic_net_spec, in_dim, critic_out_dim)
+            self.net_names.append('critic_net')
+        # init net optimizer and its lr scheduler
+        self.optim = net_util.get_optim(self.net, self.net.optim_spec)
+        self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        if not self.shared:
+            self.critic_optim = net_util.get_optim(self.critic_net, self.critic_net.optim_spec)
+            self.critic_lr_scheduler = net_util.get_lr_scheduler(self.critic_optim, self.critic_net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
+        self.post_init_nets()
+
+    @lab_api
+    def calc_pdparam(self, x, net=None):
+        '''
+        The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.
+        '''
+        out = super().calc_pdparam(x, net=net)
+        if self.shared:
+            assert ps.is_list(out), 'Shared output should be a list [pdparam, v]'
+            if len(out) == 2:  # single policy
+                pdparam = out[0]
+            else:  # multiple-task policies, still assumes 1 value
+                pdparam = out[:-1]
+            self.v_pred = out[-1].view(-1)  # cache for loss calc to prevent double-pass
+        else:  # out is pdparam
+            pdparam = out
+        return pdparam
+
+    def calc_v(self, x, net=None, use_cache=True):
+        '''
+        Forward-pass to calculate the predicted state-value from critic_net.
+        '''
+        if self.shared:  # output: policy, value
+            if use_cache:  # uses cache from calc_pdparam to prevent double-pass
+                v_pred = self.v_pred
+            else:
+                net = self.net if net is None else net
+                v_pred = net(x)[-1].view(-1)
+        else:
+            net = self.critic_net if net is None else net
+            v_pred = net(x).view(-1)
+        return v_pred
+
+    def calc_pdparam_v(self, batch):
+        '''Efficiently forward to get pdparam and v by batch for loss computation'''
+        states = batch['states']
+        if self.body.env.is_venv:
+            states = math_util.venv_unpack(states)
+        pdparam = self.calc_pdparam(states)
+        v_pred = self.calc_v(states)  # uses self.v_pred from calc_pdparam if self.shared
+        return pdparam, v_pred
+
+    def calc_ret_advs_v_targets(self, batch, v_preds):
+        '''Calculate plain returns, and advs = rets - v_preds, v_targets = rets'''
+        v_preds = v_preds.detach()  # adv does not accumulate grad
+        if self.body.env.is_venv:
+            v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
+        rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
+        advs = rets - v_preds
+        v_targets = rets
+        if self.body.env.is_venv:
+            advs = math_util.venv_unpack(advs)
+            v_targets = math_util.venv_unpack(v_targets)
+        logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
+        return advs, v_targets
+
+    def calc_nstep_advs_v_targets(self, batch, v_preds):
+        '''
+        Calculate N-step returns, and advs = nstep_rets - v_preds, v_targets = nstep_rets
+        See n-step advantage under http://rail.eecs.berkeley.edu/deeprlcourse-fa17/f17docs/lecture_5_actor_critic_pdf.pdf
+        '''
+        next_states = batch['next_states'][-1]
+        if not self.body.env.is_venv:
+            next_states = next_states.unsqueeze(dim=0)
+        with torch.no_grad():
+            next_v_pred = self.calc_v(next_states, use_cache=False)
+        v_preds = v_preds.detach()  # adv does not accumulate grad
+        if self.body.env.is_venv:
+            v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
+        nstep_rets = math_util.calc_nstep_returns(batch['rewards'], batch['dones'], next_v_pred, self.gamma, self.num_step_returns)
+        advs = nstep_rets - v_preds
+        v_targets = nstep_rets
+        if self.body.env.is_venv:
+            advs = math_util.venv_unpack(advs)
+            v_targets = math_util.venv_unpack(v_targets)
+        logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
+        return advs, v_targets
+
+    def calc_gae_advs_v_targets(self, batch, v_preds):
+        '''
+        Calculate GAE, and advs = GAE, v_targets = advs + v_preds
+        See GAE from Schulman et al. https://arxiv.org/pdf/1506.02438.pdf
+        '''
+        next_states = batch['next_states'][-1]
+        if not self.body.env.is_venv:
+            next_states = next_states.unsqueeze(dim=0)
+        with torch.no_grad():
+            next_v_pred = self.calc_v(next_states, use_cache=False)
+        v_preds = v_preds.detach()  # adv does not accumulate grad
+        if self.body.env.is_venv:
+            v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
+            next_v_pred = next_v_pred.unsqueeze(dim=0)
+        v_preds_all = torch.cat((v_preds, next_v_pred), dim=0)
+        advs = math_util.calc_gaes(batch['rewards'], batch['dones'], v_preds_all, self.gamma, self.lam)
+        v_targets = advs + v_preds
+        advs = math_util.standardize(advs)  # standardize only for advs, not v_targets
+        if self.body.env.is_venv:
+            advs = math_util.venv_unpack(advs)
+            v_targets = math_util.venv_unpack(v_targets)
+        logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
+        return advs, v_targets
+
+    def calc_policy_loss(self, batch, pdparams, advs):
+        '''Calculate the actor's policy loss'''
+        return super().calc_policy_loss(batch, pdparams, advs)
+
+    def calc_val_loss(self, v_preds, v_targets):
+        '''Calculate the critic's value loss'''
+        assert v_preds.shape == v_targets.shape, f'{v_preds.shape} != {v_targets.shape}'
+        val_loss = self.val_loss_coef * self.net.loss_fn(v_preds, v_targets)
+        logger.debug(f'Critic value loss: {val_loss:g}')
+        return val_loss
+
+    def train(self):
+        '''Train actor critic by computing the loss in batch efficiently'''
+        if util.in_eval_lab_modes():
+            return np.nan
+        clock = self.body.env.clock
+        if self.to_train == 1:
+            batch = self.sample()
+            clock.set_batch_size(len(batch))
+            pdparams, v_preds = self.calc_pdparam_v(batch)
+            advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
+            policy_loss = self.calc_policy_loss(batch, pdparams, advs)  # from actor
+            val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
+            if self.shared:  # shared network
+                loss = policy_loss + val_loss
+                self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
+            else:
+                self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
+                self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
+                loss = policy_loss + val_loss
+            # reset
+            self.to_train = 0
+            logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
+            return loss.item()
+        else:
+            return np.nan
+
+    @lab_api
+    def update(self):
+        self.body.explore_var = self.explore_var_scheduler.update(self, self.body.env.clock)
+        if self.entropy_coef_spec is not None:
+            self.body.entropy_coef = self.entropy_coef_scheduler.update(self, self.body.env.clock)
+        return self.body.explore_var
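Editor's note: a self-contained sketch of the GAE recurrence that calc_gae_advs_v_targets delegates to (math_util.calc_gaes in ConvLab). It follows Schulman et al. 2015: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t) and A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1}. This is an illustration of the technique, not ConvLab's exact implementation.

import torch

def calc_gaes_sketch(rewards, dones, v_preds_all, gamma, lam):
    # v_preds_all carries one extra element, V(s_T), for bootstrapping
    T = rewards.shape[0]
    assert v_preds_all.shape[0] == T + 1
    gaes = torch.zeros_like(rewards)
    future_gae = torch.tensor(0.0)
    not_dones = 1 - dones  # episode boundaries cut the recursion
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * v_preds_all[t + 1] * not_dones[t] - v_preds_all[t]
        gaes[t] = future_gae = delta + gamma * lam * not_dones[t] * future_gae
    return gaes

rewards = torch.tensor([1.0, 1.0, 1.0])
dones = torch.tensor([0.0, 0.0, 1.0])
v_preds_all = torch.tensor([0.5, 0.5, 0.5, 0.5])  # includes bootstrap V(s_T)
advs = calc_gaes_sketch(rewards, dones, v_preds_all, gamma=0.99, lam=0.95)
v_targets = advs + v_preds_all[:-1]  # as in calc_gae_advs_v_targets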
diff --git a/docs/build/html/_modules/convlab/agent/algorithm/base.html b/docs/build/html/_modules/convlab/agent/algorithm/base.html
new file mode 100644
index 0000000..698feb3
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/algorithm/base.html
@@ -0,0 +1,366 @@
+<!-- Sphinx HTML page scaffolding (head, nav, sidebar) omitted; page title: "convlab.agent.algorithm.base — ConvLab 0.1 documentation" -->
Source code for convlab.agent.algorithm.base
+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+import pydash as ps
+
+from convlab.agent.net import net_util
+from convlab.lib import logger, util
+from convlab.lib.decorator import lab_api
+
+logger = logger.get_logger(__name__)
+
+
+class Algorithm(ABC):
+    '''
+    Abstract class ancestor to all Algorithms,
+    specifies the necessary design blueprint for agent to work in Lab.
+    Mostly, implement just the abstract methods and properties.
+    '''
+
+    def __init__(self, agent, global_nets=None):
+        '''
+        @param {*} agent is the container for algorithm and related components, and interfaces with env.
+        '''
+        self.agent = agent
+        self.algorithm_spec = agent.agent_spec['algorithm']
+        self.name = self.algorithm_spec['name']
+        self.net_spec = agent.agent_spec.get('net', None)
+        if ps.get(agent.agent_spec, 'memory'):
+            self.memory_spec = agent.agent_spec['memory']
+        self.body = self.agent.body
+        self.init_algorithm_params()
+        self.init_nets(global_nets)
+        logger.info(util.self_desc(self))
+
+    @abstractmethod
+    @lab_api
+    def init_algorithm_params(self):
+        '''Initialize other algorithm parameters'''
+        raise NotImplementedError
+
+    @abstractmethod
+    @lab_api
+    def init_nets(self, global_nets=None):
+        '''Initialize the neural network from the spec'''
+        raise NotImplementedError
+
+    @lab_api
+    def post_init_nets(self):
+        '''
+        Method to conditionally load models.
+        Call at the end of init_nets() after setting self.net_names
+        '''
+        assert hasattr(self, 'net_names')
+        if util.in_eval_lab_modes():
+            logger.info(f'Loaded algorithm models for lab_mode: {util.get_lab_mode()}')
+            self.load()
+        else:
+            logger.info(f'Initialized algorithm models for lab_mode: {util.get_lab_mode()}')
+
+    @lab_api
+    def calc_pdparam(self, x, evaluate=True, net=None):
+        '''
+        To get the pdparam for action policy sampling, do a forward pass of the appropriate net, and pick the correct outputs.
+        The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.
+        '''
+        raise NotImplementedError
+
+    def nanflat_to_data_a(self, data_name, nanflat_data_a):
+        '''Reshape nanflat_data_a, e.g. action_a, from a single pass back into the API-conforming data_a'''
+        data_names = (data_name,)
+        data_a, = self.agent.agent_space.aeb_space.init_data_s(data_names, a=self.agent.a)
+        for body, data in zip(self.agent.nanflat_body_a, nanflat_data_a):
+            e, b = body.e, body.b
+            data_a[(e, b)] = data
+        return data_a
+
+    @lab_api
+    def act(self, state):
+        '''Standard act method.'''
+        raise NotImplementedError
+
+    @abstractmethod
+    @lab_api
+    def sample(self):
+        '''Samples a batch from memory'''
+        raise NotImplementedError
+
+    @abstractmethod
+    @lab_api
+    def train(self):
+        '''Implement algorithm train, or throw NotImplementedError'''
+        if util.in_eval_lab_modes():
+            return np.nan
+        raise NotImplementedError
+
+    @abstractmethod
+    @lab_api
+    def update(self):
+        '''Implement algorithm update, or throw NotImplementedError'''
+        raise NotImplementedError
+
+    @lab_api
+    def save(self, ckpt=None):
+        '''Save net models for algorithm given the required property self.net_names'''
+        if not hasattr(self, 'net_names'):
+            logger.info('No net declared in self.net_names in init_nets(); no models to save.')
+        else:
+            net_util.save_algorithm(self, ckpt=ckpt)
+
+    @lab_api
+    def load(self):
+        '''Load net models for algorithm given the required property self.net_names'''
+        if not hasattr(self, 'net_names'):
+            logger.info('No net declared in self.net_names in init_nets(); no models to load.')
+        else:
+            net_util.load_algorithm(self)
+        # set decayable variables to final values
+        for k, v in vars(self).items():
+            if k.endswith('_scheduler'):
+                var_name = k.replace('_scheduler', '')
+                if hasattr(v, 'end_val'):
+                    setattr(self.body, var_name, v.end_val)
+
+    # NOTE optional extension for multi-agent-env
+
+    @lab_api
+    def space_act(self, state_a):
+        '''Interface-level agent act method for all its bodies. Resolves each body's state, gets its action, and composes them into action_a.'''
+        data_names = ('action',)
+        action_a, = self.agent.agent_space.aeb_space.init_data_s(data_names, a=self.agent.a)
+        for eb, body in util.ndenumerate_nonan(self.agent.body_a):
+            state = state_a[eb]
+            self.body = body
+            action_a[eb] = self.act(state)
+        # set body reference back to default
+        self.body = self.agent.nanflat_body_a[0]
+        return action_a
+
+    @lab_api
+    def space_sample(self):
+        '''Samples a batch from memory'''
+        batches = []
+        for body in self.agent.nanflat_body_a:
+            self.body = body
+            batches.append(self.sample())
+        # set body reference back to default
+        self.body = self.agent.nanflat_body_a[0]
+        batch = util.concat_batches(batches)
+        batch = util.to_torch_batch(batch, self.net.device, self.body.memory.is_episodic)
+        return batch
+
+    @lab_api
+    def space_train(self):
+        if util.in_eval_lab_modes():
+            return np.nan
+        losses = []
+        for body in self.agent.nanflat_body_a:
+            self.body = body
+            losses.append(self.train())
+        # set body reference back to default
+        self.body = self.agent.nanflat_body_a[0]
+        loss_a = self.nanflat_to_data_a('loss', losses)
+        return loss_a
+
+    @lab_api
+    def space_update(self):
+        explore_vars = []
+        for body in self.agent.nanflat_body_a:
+            self.body = body
+            explore_vars.append(self.update())
+        # set body reference back to default
+        self.body = self.agent.nanflat_body_a[0]
+        explore_var_a = self.nanflat_to_data_a('explore_var', explore_vars)
+        return explore_var_a
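Editor's note: a hedged sketch of the subclass contract Algorithm imposes; the five abstract methods above must be implemented. This random-action baseline is hypothetical, not part of ConvLab, and assumes a discrete action space on self.body.

import numpy as np

from convlab.agent.algorithm.base import Algorithm
from convlab.lib.decorator import lab_api


class RandomBaseline(Algorithm):
    @lab_api
    def init_algorithm_params(self):
        self.to_train = 0  # never signals that training is due

    @lab_api
    def init_nets(self, global_nets=None):
        self.net_names = []  # no networks, so save()/load() have nothing to do

    @lab_api
    def act(self, state):
        return np.random.randint(self.body.action_dim)  # uniform random action

    @lab_api
    def sample(self):
        return {}  # no memory to sample from

    @lab_api
    def train(self):
        return np.nan  # no loss

    @lab_api
    def update(self):
        return np.nan  # no explore_var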
diff --git a/docs/build/html/_modules/convlab/agent/algorithm/dqn.html b/docs/build/html/_modules/convlab/agent/algorithm/dqn.html
new file mode 100644
index 0000000..daa37e5
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/algorithm/dqn.html
@@ -0,0 +1,547 @@
+<!-- Sphinx HTML page scaffolding (head, nav, sidebar) omitted; page title: "convlab.agent.algorithm.dqn — ConvLab 0.1 documentation" -->
Source code for convlab.agent.algorithm.dqn
+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import numpy as np
+import torch
+
+from convlab.agent import memory
+from convlab.agent import net
+from convlab.agent.algorithm.sarsa import SARSA
+from convlab.agent.net import net_util
+from convlab.lib import logger, util
+from convlab.lib.decorator import lab_api
+
+logger = logger.get_logger(__name__)
+
+
+
+class VanillaDQN(SARSA):
+    '''
+    Implementation of a simple DQN algorithm.
+    Algorithm:
+        1. Collect some examples by acting in the environment and store them in a replay memory
+        2. Every K steps sample N examples from replay memory
+        3. For each example calculate the target (bootstrapped estimate of the discounted value of the state and action taken), y, using a neural network to approximate the Q function. s' is the next state following the action actually taken.
+                y_t = r_t + gamma * max_a Q(s_t', a)
+        4. For each example calculate the current estimate of the discounted value of the state and action taken
+                x_t = Q(s_t, a_t)
+        5. Calculate L(x, y) where L is a regression loss (e.g. MSE)
+        6. Calculate the gradient of L with respect to all the parameters in the network and update the network parameters using the gradient
+        7. Repeat steps 3 - 6 M times
+        8. Repeat steps 2 - 7 Z times
+        9. Repeat steps 1 - 8
+
+    For more information on Q-Learning see Sergey Levine's lectures 6 and 7 from CS294-112 Fall 2017
+    https://www.youtube.com/playlist?list=PLkFD6_40KJIznC9CDbVTjAF2oyt8_VAe3
+
+    e.g. algorithm_spec
+    "algorithm": {
+        "name": "VanillaDQN",
+        "action_pdtype": "Argmax",
+        "action_policy": "epsilon_greedy",
+        "explore_var_spec": {
+            "name": "linear_decay",
+            "start_val": 1.0,
+            "end_val": 0.1,
+            "start_step": 10,
+            "end_step": 1000,
+        },
+        "gamma": 0.99,
+        "training_batch_iter": 8,
+        "training_iter": 4,
+        "training_frequency": 10,
+        "training_start_step": 10,
+    }
+    '''
+
+    @lab_api
+    def init_algorithm_params(self):
+        # set default
+        util.set_attr(self, dict(
+            action_pdtype='Argmax',
+            action_policy='epsilon_greedy',
+            explore_var_spec=None,
+        ))
+        util.set_attr(self, self.algorithm_spec, [
+            'action_pdtype',
+            'action_policy',
+            # explore_var is epsilon, tau or etc. depending on the action policy
+            # these control the trade-off between exploration and exploitation
+            'explore_var_spec',
+            'gamma',  # the discount factor
+            'training_batch_iter',  # how many gradient updates per batch
+            'training_iter',  # how many batches to train each time
+            'training_frequency',  # how often to train (once every few timesteps)
+            'training_start_step',  # how long before starting training
+        ])
+        super().init_algorithm_params()
+
+    @lab_api
+    def init_nets(self, global_nets=None):
+        '''Initialize the neural network used to learn the Q function from the spec'''
+        if self.algorithm_spec['name'] == 'VanillaDQN':
+            assert all(k not in self.net_spec for k in ['update_type', 'update_frequency', 'polyak_coef']), 'Network update not available for VanillaDQN; use DQN.'
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body)
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dim, out_dim)
+        self.net_names = ['net']
+        # init net optimizer and its lr scheduler
+        self.optim = net_util.get_optim(self.net, self.net.optim_spec)
+        self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
+        self.post_init_nets()
+
+    def calc_q_loss(self, batch):
+        '''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
+        states = batch['states']
+        next_states = batch['next_states']
+        q_preds = self.net(states)
+        with torch.no_grad():
+            next_q_preds = self.net(next_states)
+        act_q_preds = q_preds.gather(-1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
+        # Bellman equation: compute max_q_targets using reward and max estimated Q values (0 if no next_state)
+        max_next_q_preds, _ = next_q_preds.max(dim=-1, keepdim=True)
+        max_q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * max_next_q_preds
+        logger.debug(f'act_q_preds: {act_q_preds}\nmax_q_targets: {max_q_targets}')
+        q_loss = self.net.loss_fn(act_q_preds, max_q_targets)
+
+        # TODO use the same loss_fn but do not reduce yet
+        if 'Prioritized' in util.get_class_name(self.body.memory):  # PER
+            errors = (max_q_targets - act_q_preds.detach()).abs().cpu().numpy()
+            self.body.memory.update_priorities(errors)
+        return q_loss
+
+    @lab_api
+    def act(self, state):
+        '''Selects and returns a discrete action for body using the action policy'''
+        return super().act(state)
+
+    @lab_api
+    def sample(self):
+        '''Samples a batch from memory of size self.memory_spec['batch_size']'''
+        batch = self.body.memory.sample()
+        batch = util.to_torch_batch(batch, self.net.device, self.body.memory.is_episodic)
+        return batch
+
+    @lab_api
+    def train(self):
+        '''
+        Completes one training step for the agent if it is time to train,
+        i.e. the environment timestep is greater than the minimum training timestep and a multiple of the training_frequency.
+        Each training step consists of sampling n batches from the agent's memory.
+        For each of the batches, the target Q values (q_targets) are computed and a single training step is taken k times.
+        Otherwise this function does nothing.
+        '''
+        if util.in_eval_lab_modes():
+            return np.nan
+        clock = self.body.env.clock
+        if self.to_train == 1:
+            total_loss = torch.tensor(0.0)
+            for _ in range(self.training_iter):
+                batch = self.sample()
+                clock.set_batch_size(len(batch))
+                for _ in range(self.training_batch_iter):
+                    loss = self.calc_q_loss(batch)
+                    self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
+                    total_loss += loss
+            loss = total_loss / (self.training_iter * self.training_batch_iter)
+            # reset
+            self.to_train = 0
+            logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
+            return loss.item()
+        else:
+            return np.nan
+
+    @lab_api
+    def update(self):
+        '''Update the agent after training'''
+        return super().update()
+
[docs]class DQNBase(VanillaDQN): + ''' + Implementation of the base DQN algorithm. + The algorithm follows the same general approach as VanillaDQN but is more general since it allows + for two different networks (through self.net and self.target_net). + + self.net is used to act, and is the network trained. + self.target_net is used to estimate the maximum value of the Q-function in the next state when calculating the target (see VanillaDQN comments). + self.target_net is updated periodically to either match self.net (self.net.update_type = "replace") or to be a weighted average of self.net and the previous self.target_net (self.net.update_type = "polyak") + If desired, self.target_net can be updated slowly, and this can help to stabilize learning. + + It also allows for different nets to be used to select the action in the next state and to evaluate the value of that action through self.online_net and self.eval_net. This can help reduce the tendency of DQN's to overestimate the value of the Q-function. Following this approach leads to the DoubleDQN algorithm. + + Setting all nets to self.net reduces to the VanillaDQN case. + ''' + +
[docs] @lab_api + def init_nets(self, global_nets=None): + '''Initialize networks''' + if self.algorithm_spec['name'] == 'DQNBase': + assert all(k not in self.net_spec for k in ['update_type', 'update_frequency', 'polyak_coef']), 'Network update not available for DQNBase; use DQN.' + in_dim = self.body.state_dim + out_dim = net_util.get_out_dim(self.body) + NetClass = getattr(net, self.net_spec['type']) + self.net = NetClass(self.net_spec, in_dim, out_dim) + self.target_net = NetClass(self.net_spec, in_dim, out_dim) + self.net_names = ['net', 'target_net'] + # init net optimizer and its lr scheduler + self.optim = net_util.get_optim(self.net, self.net.optim_spec) + self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec) + net_util.set_global_nets(self, global_nets) + self.post_init_nets() + self.online_net = self.target_net + self.eval_net = self.target_net
+ +
[docs] def calc_q_loss(self, batch): + '''Compute the Q value loss using predicted and target Q values from the appropriate networks''' + states = batch['states'] + next_states = batch['next_states'] + q_preds = self.net(states) + with torch.no_grad(): + # Use online_net to select actions in next state + online_next_q_preds = self.online_net(next_states) + # Use eval_net to calculate next_q_preds for actions chosen by online_net + next_q_preds = self.eval_net(next_states) + act_q_preds = q_preds.gather(-1, batch['actions'].long().unsqueeze(-1)).squeeze(-1) + online_actions = online_next_q_preds.argmax(dim=-1, keepdim=True) + max_next_q_preds = next_q_preds.gather(-1, online_actions).squeeze(-1) + max_q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * max_next_q_preds + logger.debug(f'act_q_preds: {act_q_preds}\nmax_q_targets: {max_q_targets}') + q_loss = self.net.loss_fn(act_q_preds, max_q_targets) + + # TODO use the same loss_fn but do not reduce yet + if 'Prioritized' in util.get_class_name(self.body.memory): # PER + errors = (max_q_targets - act_q_preds.detach()).abs().cpu().numpy() + self.body.memory.update_priorities(errors) + return q_loss
+ +
[docs] def update_nets(self): + if util.frame_mod(self.body.env.clock.frame, self.net.update_frequency, self.body.env.num_envs): + if self.net.update_type == 'replace': + net_util.copy(self.net, self.target_net) + elif self.net.update_type == 'polyak': + net_util.polyak_update(self.net, self.target_net, self.net.polyak_coef) + else: + raise ValueError('Unknown net.update_type. Should be "replace" or "polyak". Exiting.')
+ +
[docs] @lab_api + def update(self): + '''Updates self.target_net and the explore variables''' + self.update_nets() + return super().update()
+ + +
[docs]class DQN(DQNBase): + ''' + DQN class + + e.g. algorithm_spec + "algorithm": { + "name": "DQN", + "action_pdtype": "Argmax", + "action_policy": "epsilon_greedy", + "explore_var_spec": { + "name": "linear_decay", + "start_val": 1.0, + "end_val": 0.1, + "start_step": 10, + "end_step": 1000, + }, + "gamma": 0.99, + "training_batch_iter": 8, + "training_iter": 4, + "training_frequency": 10, + "training_start_step": 10 + } + ''' +
[docs] @lab_api + def init_nets(self, global_nets=None): + super().init_nets(global_nets)
+ + +
[docs]class WarmUpDQN(DQN): + ''' + DQN class + + e.g. algorithm_spec + "algorithm": { + "name": "WarmUpDQN", + "action_pdtype": "Argmax", + "action_policy": "epsilon_greedy", + "warmup_epi": 300, + "explore_var_spec": { + "name": "linear_decay", + "start_val": 1.0, + "end_val": 0.1, + "start_step": 10, + "end_step": 1000, + }, + "gamma": 0.99, + "training_batch_iter": 8, + "training_iter": 4, + "training_frequency": 10, + "training_start_step": 10 + } + ''' + def __init__(self, agent, global_nets=None): + super().__init__(agent, global_nets) + util.set_attr(self, self.algorithm_spec, [ + 'warmup_epi', + ]) + # create the extra replay memory for warm-up + MemoryClass = getattr(memory, self.memory_spec['warmup_name']) + self.body.warmup_memory = MemoryClass(self.memory_spec, self.body) + +
[docs] @lab_api + def init_nets(self, global_nets=None): + super().init_nets(global_nets)
+ +
[docs] def warmup_sample(self): + '''Samples a batch from warm-up memory''' + batch = self.body.warmup_memory.sample() + batch = util.to_torch_batch(batch, self.net.device, self.body.warmup_memory.is_episodic) + return batch
+ +
[docs] def train(self): + if util.in_eval_lab_modes(): + return np.nan + clock = self.body.env.clock + if self.to_train == 1: + total_loss = torch.tensor(0.0) + for _ in range(self.training_iter): + batches = [] + if self.body.warmup_memory.size >= self.body.warmup_memory.batch_size: + batches.append(self.warmup_sample()) + if self.body.memory.size >= self.body.memory.batch_size: + batches.append(self.sample()) + clock.set_batch_size(sum(len(batch) for batch in batches)) + for batch in batches: + for _ in range(self.training_batch_iter): + loss = self.calc_q_loss(batch) + self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net) + total_loss += loss + loss = total_loss / (self.training_iter * self.training_batch_iter) + # reset + self.to_train = 0 + logger.info(f'Trained {self.name} at epi: {clock.epi}, warmup_size: {self.body.warmup_memory.size}, memory_size: {self.body.memory.size}, loss: {loss:g}') + return loss.item() + else: + return np.nan
+ + +
[docs]class DoubleDQN(DQN): + ''' + Double-DQN (DDQN) class + + e.g. algorithm_spec + "algorithm": { + "name": "DDQN", + "action_pdtype": "Argmax", + "action_policy": "epsilon_greedy", + "explore_var_spec": { + "name": "linear_decay", + "start_val": 1.0, + "end_val": 0.1, + "start_step": 10, + "end_step": 1000, + }, + "gamma": 0.99, + "training_batch_iter": 8, + "training_iter": 4, + "training_frequency": 10, + "training_start_step": 10 + } + ''' +
[docs] @lab_api + def init_nets(self, global_nets=None): + super().init_nets(global_nets) + self.online_net = self.net + self.eval_net = self.target_net
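
To make the Bellman target above concrete, the following is a minimal, self-contained sketch of the VanillaDQN loss on toy tensors. The nn.Linear stand-in for the Q network, the fixed gamma, and all values are hypothetical placeholders, not ConvLab objects:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    gamma = 0.99
    q_net = nn.Linear(4, 3)  # stand-in Q network: 4-dim state -> 3 discrete actions

    states = torch.randn(5, 4)  # batch of 5 transitions
    next_states = torch.randn(5, 4)
    actions = torch.tensor([0, 2, 1, 0, 2])
    rewards = torch.tensor([1.0, 0.0, 0.5, 0.0, 1.0])
    dones = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0])

    q_preds = q_net(states)
    with torch.no_grad():  # the target side carries no gradient
        next_q_preds = q_net(next_states)
    # Q(s, a) for the actions actually taken
    act_q_preds = q_preds.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    # y = r + gamma * max_a Q(s', a), with the bootstrap zeroed at episode end
    max_q_targets = rewards + gamma * (1 - dones) * next_q_preds.max(dim=-1)[0]
    q_loss = nn.functional.mse_loss(act_q_preds, max_q_targets)
    print(q_loss.item())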
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/algorithm/external.html b/docs/build/html/_modules/convlab/agent/algorithm/external.html
new file mode 100644
index 0000000..99b40bc
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/algorithm/external.html
@@ -0,0 +1,261 @@

Source code for convlab.agent.algorithm.external

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+'''
+The external policy algorithm.
+Wraps pretrained, rule-based, word-level, or end-to-end dialog policies as an agent algorithm.
+'''
+from copy import deepcopy
+
+import pydash as ps
+
+from convlab.agent.algorithm import policy_util
+from convlab.agent.algorithm.base import Algorithm
+from convlab.lib import logger, util
+from convlab.lib.decorator import lab_api
+from convlab.modules import policy, word_policy, e2e
+
+logger = logger.get_logger(__name__)
+
+
+
+class ExternalPolicy(Algorithm):
+    '''
+    Algorithm wrapper that delegates acting to an external policy module (rule-based policy, word-level policy, or end-to-end model)
+    '''
+
+    @lab_api
+    def init_algorithm_params(self):
+        '''Initialize other algorithm parameters'''
+        # set default
+        util.set_attr(self, dict(
+            action_pdtype='default',
+            action_policy='default',
+        ))
+        util.set_attr(self, self.algorithm_spec, [
+            'policy_name',
+            'action_pdtype',
+            'action_policy',
+        ])
+        self.action_policy = getattr(policy_util, self.action_policy)
+        self.policy = None
+        if 'word_policy' in self.algorithm_spec:
+            params = deepcopy(ps.get(self.algorithm_spec, 'word_policy'))
+            PolicyClass = getattr(word_policy, params.pop('name'))
+        elif 'e2e' in self.algorithm_spec:
+            params = deepcopy(ps.get(self.algorithm_spec, 'e2e'))
+            PolicyClass = getattr(e2e, params.pop('name'))
+        else:
+            params = deepcopy(ps.get(self.algorithm_spec, 'policy'))
+            PolicyClass = getattr(policy, params.pop('name'))
+        self.policy = PolicyClass(**params)
+
+    def reset(self):
+        self.policy.init_session()
+
+    @lab_api
+    def init_nets(self, global_nets=None):
+        pass
+
+    @lab_api
+    def act(self, state):
+        action = self.policy.predict(state)
+        return action
+
+    @lab_api
+    def sample(self):
+        pass
+
+    @lab_api
+    def train(self):
+        pass
+
+    @lab_api
+    def update(self):
+        pass
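
For illustration, an algorithm_spec along the following lines would make init_algorithm_params() above resolve and construct a policy from convlab.modules.policy. The policy name and fields here are hypothetical placeholders; check what your installed convlab.modules.policy actually exports:

    # hypothetical spec fragment for ExternalPolicy
    algorithm_spec = {
        'name': 'ExternalPolicy',
        'policy_name': 'RuleBasedPolicy',       # placeholder name
        'action_pdtype': 'default',
        'action_policy': 'default',
        'policy': {'name': 'RuleBasedPolicy'},  # resolved via getattr(policy, name)
    }
    # init_algorithm_params() then does, in effect:
    #   params = deepcopy(algorithm_spec['policy'])
    #   PolicyClass = getattr(policy, params.pop('name'))
    #   self.policy = PolicyClass(**params)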
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/algorithm/policy_util.html b/docs/build/html/_modules/convlab/agent/algorithm/policy_util.html
new file mode 100644
index 0000000..c2f6f15
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/algorithm/policy_util.html
@@ -0,0 +1,504 @@

Source code for convlab.agent.algorithm.policy_util

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+# Action policy module
+# Constructs action probability distribution used by agent to sample action and calculate log_prob, entropy, etc.
+from gym import spaces
+from torch import distributions
+
+# from convlab.env.wrapper import LazyFrames
+from convlab.lib import distribution, logger, math_util, util
+
+logger = logger.get_logger(__name__)
+
+# register custom distributions
+setattr(distributions, 'Argmax', distribution.Argmax)
+setattr(distributions, 'GumbelCategorical', distribution.GumbelCategorical)
+setattr(distributions, 'MultiCategorical', distribution.MultiCategorical)
+# probability distributions constraints for different action types; the first in the list is the default
+ACTION_PDS = {
+    'continuous': ['Normal', 'Beta', 'Gumbel', 'LogNormal'],
+    'multi_continuous': ['MultivariateNormal'],
+    'discrete': ['Categorical', 'Argmax', 'GumbelCategorical'],
+    'multi_discrete': ['MultiCategorical'],
+    'multi_binary': ['Bernoulli'],
+}
+
+
+
+def get_action_type(action_space):
+    '''Get the action type, used to choose the probability distribution to sample actions from NN logits output'''
+    if isinstance(action_space, spaces.Box):
+        shape = action_space.shape
+        assert len(shape) == 1
+        if shape[0] == 1:
+            return 'continuous'
+        else:
+            return 'multi_continuous'
+    elif isinstance(action_space, spaces.Discrete):
+        return 'discrete'
+    elif isinstance(action_space, spaces.MultiDiscrete):
+        return 'multi_discrete'
+    elif isinstance(action_space, spaces.MultiBinary):
+        return 'multi_binary'
+    else:
+        raise NotImplementedError
+
+
+# action_policy base methods
+
+def get_action_pd_cls(action_pdtype, action_type):
+    '''
+    Verify and get the action prob. distribution class for construction
+    Called by body at init to set its own ActionPD
+    '''
+    pdtypes = ACTION_PDS[action_type]
+    assert action_pdtype in pdtypes, f'Pdtype {action_pdtype} is not compatible/supported with action_type {action_type}. Options are: {pdtypes}'
+    ActionPD = getattr(distributions, action_pdtype)
+    return ActionPD
+
+
+def guard_tensor(state, body):
+    '''Guard-cast tensor before being input to network'''
+    # if isinstance(state, LazyFrames):
+    #     state = state.__array__()  # realize data
+    state = torch.from_numpy(state.astype(np.float32))
+    if not body.env.is_venv or util.in_eval_lab_modes():
+        # singleton state, unsqueeze as minibatch for net input
+        state = state.unsqueeze(dim=0)
+    return state
+
+
+def calc_pdparam(state, algorithm, body):
+    '''
+    Prepare the state and run algorithm.calc_pdparam to get pdparam for action_pd
+    @param tensor:state For pdparam = net(state)
+    @param algorithm The algorithm containing self.net
+    @param body Body which links algorithm to the env which the action is for
+    @returns tensor:pdparam
+    @example
+
+    pdparam = calc_pdparam(state, algorithm, body)
+    action_pd = ActionPD(logits=pdparam)  # e.g. ActionPD is Categorical
+    action = action_pd.sample()
+    '''
+    if not torch.is_tensor(state):  # don't need to cast if already a tensor
+        state = guard_tensor(state, body)
+        state = state.to(algorithm.net.device)
+    pdparam = algorithm.calc_pdparam(state)
+    return pdparam
+
+
+def init_action_pd(ActionPD, pdparam):
+    '''
+    Initialize the action_pd for discrete or continuous actions:
+    - discrete: action_pd = ActionPD(logits)
+    - continuous: action_pd = ActionPD(loc, scale)
+    '''
+    if 'logits' in ActionPD.arg_constraints:  # discrete
+        action_pd = ActionPD(logits=pdparam)
+    else:  # continuous, args = loc and scale
+        if isinstance(pdparam, list):  # split output
+            loc, scale = pdparam
+        else:
+            loc, scale = pdparam.transpose(0, 1)
+        # scale (stdev) must be > 0, use softplus with a small positive offset
+        scale = F.softplus(scale) + 1e-8
+        if isinstance(pdparam, list):  # split output
+            # construct covars from a batched scale tensor
+            covars = torch.diag_embed(scale)
+            action_pd = ActionPD(loc=loc, covariance_matrix=covars)
+        else:
+            action_pd = ActionPD(loc=loc, scale=scale)
+    return action_pd
+
+
+def sample_action(ActionPD, pdparam):
+    '''
+    Convenience method to sample action(s) from action_pd = ActionPD(pdparam)
+    Works with batched pdparam too
+    @returns tensor:action Sampled action(s)
+    @example
+
+    # policy contains:
+    pdparam = calc_pdparam(state, algorithm, body)
+    action = sample_action(body.ActionPD, pdparam)
+    '''
+    action_pd = init_action_pd(ActionPD, pdparam)
+    action = action_pd.sample()
+    return action
+
+
+# action_policy used by agent
+
+def default(state, algorithm, body):
+    '''Plain policy by direct sampling from a default action probability defined by body.ActionPD'''
+    pdparam = calc_pdparam(state, algorithm, body)
+    action = sample_action(body.ActionPD, pdparam)
+    return action
+
+
+def random(state, algorithm, body):
+    '''Random action using gym.action_space.sample(), with the same format as default()'''
+    if body.env.is_venv and not util.in_eval_lab_modes():
+        _action = [body.action_space.sample() for _ in range(body.env.num_envs)]
+    else:
+        _action = body.action_space.sample()
+    action = torch.tensor([_action])
+    return action
+
+
+def epsilon_greedy(state, algorithm, body):
+    '''Epsilon-greedy policy: with probability epsilon, take a random action; otherwise sample via default()'''
+    epsilon = body.explore_var
+    if epsilon > np.random.rand():
+        return random(state, algorithm, body)
+    else:
+        return default(state, algorithm, body)
+
+
+def boltzmann(state, algorithm, body):
+    '''
+    Boltzmann policy: adjust pdparam with temperature tau; the higher tau, the more randomness/noise in the action.
+    '''
+    tau = body.explore_var
+    pdparam = calc_pdparam(state, algorithm, body)
+    pdparam /= tau
+    action = sample_action(body.ActionPD, pdparam)
+    return action
+
+
+def warmup_epsilon_greedy(state, algorithm, body):
+    action = default(state, algorithm, body)
+
+    if util.in_eval_lab_modes():
+        return action
+
+    epsilon = body.explore_var
+    if epsilon > np.random.rand():
+        action = random(state, algorithm, body)
+    if body.env.clock.epi < algorithm.warmup_epi:
+        if hasattr(body, 'state'):
+            action = rule_guide(body.state, algorithm, body)
+        else:
+            action = rule_guide(state, algorithm, body)
+    return action
+
+
+def warmup_default(state, algorithm, body):
+    action = default(state, algorithm, body)
+
+    if util.in_eval_lab_modes():
+        return action
+
+    if body.env.clock.epi < algorithm.warmup_epi:
+        if hasattr(body, 'state'):
+            action = rule_guide(body.state, algorithm, body)
+        else:
+            action = rule_guide(state, algorithm, body)
+    return action
+
+
+def rule_guide(state, algorithm, body):
+    env = body.env.u_env
+    action = env.rule_policy(state, algorithm, body)
+    probs = torch.zeros(body.action_space.high, device=algorithm.net.device)
+    probs[action] = 1
+    action = torch.tensor(action, device=algorithm.net.device)
+    return action
+
+
+# multi-body/multi-env action_policy used by agent
+# TODO rework
+
+def multi_default(states, algorithm, body_list, pdparam):
+    '''
+    Apply default policy body-wise
+    Note, for efficiency, do a single forward pass to calculate pdparam, then call this policy like:
+    @example
+
+    pdparam = self.calc_pdparam(state)
+    action_a = self.action_policy(pdparam, self, body_list)
+    '''
+    # assert pdparam has been chunked
+    assert len(pdparam.shape) > 1 and len(pdparam) == len(body_list), f'pdparam shape: {pdparam.shape}, bodies: {len(body_list)}'
+    action_list = []
+    for idx, sub_pdparam in enumerate(pdparam):
+        body = body_list[idx]
+        guard_tensor(states[idx], body)  # for consistency with singleton inner logic
+        action = sample_action(body.ActionPD, sub_pdparam)
+        action_list.append(action)
+    action_a = torch.tensor(action_list, device=algorithm.net.device).unsqueeze(dim=1)
+    return action_a
+
+
+def multi_random(states, algorithm, body_list, pdparam):
+    '''Apply random policy body-wise.'''
+    action_list = []
+    for idx, body in enumerate(body_list):
+        action = random(states[idx], algorithm, body)
+        action_list.append(action)
+    action_a = torch.tensor(action_list, device=algorithm.net.device).unsqueeze(dim=1)
+    return action_a
+
+
+def multi_epsilon_greedy(states, algorithm, body_list, pdparam):
+    '''Apply epsilon-greedy policy body-wise'''
+    assert len(pdparam) > 1 and len(pdparam) == len(body_list), f'pdparam shape: {pdparam.shape}, bodies: {len(body_list)}'
+    action_list = []
+    for idx, sub_pdparam in enumerate(pdparam):
+        body = body_list[idx]
+        epsilon = body.explore_var
+        if epsilon > np.random.rand():
+            action = random(states[idx], algorithm, body)
+        else:
+            guard_tensor(states[idx], body)  # for consistency with singleton inner logic
+            action = sample_action(body.ActionPD, sub_pdparam)
+        action_list.append(action)
+    action_a = torch.tensor(action_list, device=algorithm.net.device).unsqueeze(dim=1)
+    return action_a
+
+
+def multi_boltzmann(states, algorithm, body_list, pdparam):
+    '''Apply Boltzmann policy body-wise'''
+    assert len(pdparam) > 1 and len(pdparam) == len(body_list), f'pdparam shape: {pdparam.shape}, bodies: {len(body_list)}'
+    action_list = []
+    for idx, sub_pdparam in enumerate(pdparam):
+        body = body_list[idx]
+        guard_tensor(states[idx], body)  # for consistency with singleton inner logic
+        tau = body.explore_var
+        sub_pdparam /= tau
+        action = sample_action(body.ActionPD, sub_pdparam)
+        action_list.append(action)
+    action_a = torch.tensor(action_list, device=algorithm.net.device).unsqueeze(dim=1)
+    return action_a
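
As a minimal illustration of the sampling path shared by default() and epsilon_greedy() above, here is the distribution logic in isolation, using torch.distributions.Categorical directly on toy logits (no ConvLab body or algorithm objects involved):

    import numpy as np
    import torch
    from torch import distributions

    torch.manual_seed(0)
    epsilon = 0.1
    num_actions = 4
    pdparam = torch.tensor([2.0, 0.5, 0.1, -1.0])  # toy logits, as if from a net

    if epsilon > np.random.rand():
        action = torch.randint(num_actions, (1,))  # explore: uniform random action
    else:
        action_pd = distributions.Categorical(logits=pdparam)  # the ActionPD for discrete
        action = action_pd.sample()  # exploit: sample from the policy distribution
    print(action.item())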
+# action policy update methods
+
+class VarScheduler:
+    '''
+    Variable scheduler for decaying variables such as explore_var (epsilon, tau) and entropy
+
+    e.g. spec
+    "explore_var_spec": {
+        "name": "linear_decay",
+        "start_val": 1.0,
+        "end_val": 0.1,
+        "start_step": 0,
+        "end_step": 800,
+    },
+    '''
+
+    def __init__(self, var_decay_spec=None):
+        self._updater_name = 'no_decay' if var_decay_spec is None else var_decay_spec['name']
+        self._updater = getattr(math_util, self._updater_name)
+        util.set_attr(self, dict(
+            start_val=np.nan,
+        ))
+        util.set_attr(self, var_decay_spec, [
+            'start_val',
+            'end_val',
+            'start_step',
+            'end_step',
+        ])
+        if not getattr(self, 'end_val', None):
+            self.end_val = self.start_val
+
+    def update(self, algorithm, clock):
+        '''Get an updated value for var'''
+        if (util.in_eval_lab_modes()) or self._updater_name == 'no_decay':
+            return self.end_val
+        step = clock.get()
+        val = self._updater(self.start_val, self.end_val, self.start_step, self.end_step, step)
+        return val
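
Assuming math_util.linear_decay interpolates linearly between start_val and end_val over [start_step, end_step] and clamps outside that window, the docstring's spec behaves roughly like this self-contained sketch (an approximate re-implementation for illustration, not the convlab.lib.math_util code):

    def linear_decay(start_val, end_val, start_step, end_step, step):
        '''Approximate linear decay with clamping, for illustration only'''
        if step < start_step:
            return start_val
        slope = (end_val - start_val) / (end_step - start_step)
        val = start_val + slope * (step - start_step)
        # clamp at end_val once the decay window has passed
        return max(val, end_val) if end_val < start_val else min(val, end_val)

    for step in [0, 400, 800, 1200]:
        print(step, linear_decay(1.0, 0.1, 0, 800, step))
    # 0 -> 1.0, 400 -> 0.55, 800 -> 0.1, 1200 -> 0.1 (clamped)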
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/algorithm/ppo.html b/docs/build/html/_modules/convlab/agent/algorithm/ppo.html
new file mode 100644
index 0000000..fea9a37
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/algorithm/ppo.html
@@ -0,0 +1,401 @@

Source code for convlab.agent.algorithm.ppo

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+from copy import deepcopy
+
+import numpy as np
+import torch
+
+from convlab.agent.algorithm import policy_util
+from convlab.agent.algorithm.actor_critic import ActorCritic
+from convlab.agent.net import net_util
+from convlab.lib import logger, math_util, util
+from convlab.lib.decorator import lab_api
+
+logger = logger.get_logger(__name__)
+
+
+
+class PPO(ActorCritic):
+    '''
+    Implementation of PPO
+    This is actually just ActorCritic with a custom loss function
+    Original paper: "Proximal Policy Optimization Algorithms"
+    https://arxiv.org/pdf/1707.06347.pdf
+
+    Adapted from OpenAI baselines, CPU version https://github.com/openai/baselines/tree/master/baselines/ppo1
+    Algorithm:
+    for iteration = 1, 2, 3, ... do
+        for actor = 1, 2, 3, ..., N do
+            run policy pi_old in env for T timesteps
+            compute advantage A_1, ..., A_T
+        end for
+        optimize surrogate L wrt theta, with K epochs and minibatch size M <= NT
+    end for
+
+    e.g. algorithm_spec
+    "algorithm": {
+        "name": "PPO",
+        "action_pdtype": "default",
+        "action_policy": "default",
+        "explore_var_spec": null,
+        "gamma": 0.99,
+        "lam": 1.0,
+        "clip_eps_spec": {
+            "name": "linear_decay",
+            "start_val": 0.01,
+            "end_val": 0.001,
+            "start_step": 100,
+            "end_step": 5000,
+        },
+        "entropy_coef_spec": {
+            "name": "linear_decay",
+            "start_val": 0.01,
+            "end_val": 0.001,
+            "start_step": 100,
+            "end_step": 5000,
+        },
+        "minibatch_size": 256,
+        "training_frequency": 1,
+        "training_epoch": 8,
+    }
+
+    e.g. special net_spec param "shared" to share/separate Actor/Critic
+    "net": {
+        "type": "MLPNet",
+        "shared": true,
+        ...
+    '''
+
+    @lab_api
+    def init_algorithm_params(self):
+        '''Initialize other algorithm parameters'''
+        # set default
+        util.set_attr(self, dict(
+            action_pdtype='default',
+            action_policy='default',
+            explore_var_spec=None,
+            entropy_coef_spec=None,
+            minibatch_size=4,
+            val_loss_coef=1.0,
+        ))
+        util.set_attr(self, self.algorithm_spec, [
+            'action_pdtype',
+            'action_policy',
+            # theoretically, PPO does not need an exploration variable update; this implementation merely keeps the option
+            'explore_var_spec',
+            'gamma',
+            'lam',
+            'clip_eps_spec',
+            'entropy_coef_spec',
+            'val_loss_coef',
+            'minibatch_size',
+            'training_frequency',  # horizon
+            'training_epoch',
+        ])
+        self.to_train = 0
+        self.action_policy = getattr(policy_util, self.action_policy)
+        self.explore_var_scheduler = policy_util.VarScheduler(self.explore_var_spec)
+        self.body.explore_var = self.explore_var_scheduler.start_val
+        # extra variable decays for PPO
+        self.clip_eps_scheduler = policy_util.VarScheduler(self.clip_eps_spec)
+        self.body.clip_eps = self.clip_eps_scheduler.start_val
+        if self.entropy_coef_spec is not None:
+            self.entropy_coef_scheduler = policy_util.VarScheduler(self.entropy_coef_spec)
+            self.body.entropy_coef = self.entropy_coef_scheduler.start_val
+        # PPO uses GAE
+        self.calc_advs_v_targets = self.calc_gae_advs_v_targets
+
+    @lab_api
+    def init_nets(self, global_nets=None):
+        '''PPO uses old and new nets to calculate the ratio for the loss'''
+        super().init_nets(global_nets)
+        # create old net to calculate ratio
+        self.old_net = deepcopy(self.net)
+        assert id(self.old_net) != id(self.net)
+
+    def calc_policy_loss(self, batch, pdparams, advs):
+        '''
+        The PPO loss function (subscript t is omitted)
+        L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]
+
+        Broken down piecewise,
+        1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
+        where ratio = pi(a|s) / pi_old(a|s)
+
+        2. L^VF = E[ mse(V(s_t), V^target) ]
+
+        3. S = E[ entropy ]
+        '''
+        clip_eps = self.body.clip_eps
+        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
+        states = batch['states']
+        actions = batch['actions']
+        if self.body.env.is_venv:
+            states = math_util.venv_unpack(states)
+            actions = math_util.venv_unpack(actions)
+
+        # L^CLIP
+        log_probs = action_pd.log_prob(actions)
+        with torch.no_grad():
+            old_pdparams = self.calc_pdparam(states, net=self.old_net)
+            old_action_pd = policy_util.init_action_pd(self.body.ActionPD, old_pdparams)
+            old_log_probs = old_action_pd.log_prob(actions)
+        assert log_probs.shape == old_log_probs.shape
+        ratios = torch.exp(log_probs - old_log_probs)  # importance ratio pi(a|s) / pi_old(a|s)
+        logger.debug(f'ratios: {ratios}')
+        sur_1 = ratios * advs
+        sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
+        # flip sign because we need to maximize
+        clip_loss = -torch.min(sur_1, sur_2).mean()
+        logger.debug(f'clip_loss: {clip_loss}')
+
+        # L^VF (inherited from ActorCritic)
+
+        # S entropy bonus
+        entropy = action_pd.entropy().mean()
+        self.body.mean_entropy = entropy  # update logging variable
+        ent_penalty = -self.body.entropy_coef * entropy
+        logger.debug(f'ent_penalty: {ent_penalty}')
+
+        policy_loss = clip_loss + ent_penalty
+        logger.debug(f'PPO Actor policy loss: {policy_loss:g}')
+        return policy_loss
+
+    def train(self):
+        if util.in_eval_lab_modes():
+            return np.nan
+        clock = self.body.env.clock
+        if self.to_train == 1:
+            net_util.copy(self.net, self.old_net)  # update old net
+            batch = self.sample()
+            clock.set_batch_size(len(batch))
+            _pdparams, v_preds = self.calc_pdparam_v(batch)
+            advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
+            # piggyback on batch, but remember to not pack or unpack
+            batch['advs'], batch['v_targets'] = advs, v_targets
+            if self.body.env.is_venv:  # unpack if venv for minibatch sampling
+                for k, v in batch.items():
+                    if k not in ('advs', 'v_targets'):
+                        batch[k] = math_util.venv_unpack(v)
+            total_loss = torch.tensor(0.0)
+            for _ in range(self.training_epoch):
+                minibatches = util.split_minibatch(batch, self.minibatch_size)
+                for minibatch in minibatches:
+                    if self.body.env.is_venv:  # re-pack to restore proper shape
+                        for k, v in minibatch.items():
+                            if k not in ('advs', 'v_targets'):
+                                minibatch[k] = math_util.venv_pack(v, self.body.env.num_envs)
+                    advs, v_targets = minibatch['advs'], minibatch['v_targets']
+                    pdparams, v_preds = self.calc_pdparam_v(minibatch)
+                    policy_loss = self.calc_policy_loss(minibatch, pdparams, advs)  # from actor
+                    val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
+                    if self.shared:  # shared network
+                        loss = policy_loss + val_loss
+                        self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
+                    else:
+                        self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
+                        self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
+                        loss = policy_loss + val_loss
+                    total_loss += loss
+            loss = total_loss / self.training_epoch / len(minibatches)
+            # reset
+            self.to_train = 0
+            logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
+            return loss.item()
+        else:
+            return np.nan
+
+    @lab_api
+    def update(self):
+        self.body.explore_var = self.explore_var_scheduler.update(self, self.body.env.clock)
+        if self.entropy_coef_spec is not None:
+            self.body.entropy_coef = self.entropy_coef_scheduler.update(self, self.body.env.clock)
+        self.body.clip_eps = self.clip_eps_scheduler.update(self, self.body.env.clock)
+        return self.body.explore_var
+
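
The clipped-surrogate part of calc_policy_loss() can be demonstrated on toy tensors. This minimal sketch, with a fixed clip_eps and hand-made log-probabilities, mirrors the sur_1/sur_2 construction above:

    import torch

    clip_eps = 0.2
    log_probs = torch.tensor([-0.9, -1.2, -0.3])      # new policy log pi(a|s)
    old_log_probs = torch.tensor([-1.0, -1.0, -1.0])  # old policy log pi_old(a|s)
    advs = torch.tensor([1.0, -0.5, 2.0])

    ratios = torch.exp(log_probs - old_log_probs)
    sur_1 = ratios * advs
    sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
    clip_loss = -torch.min(sur_1, sur_2).mean()  # negated to maximize the surrogate
    print(ratios, clip_loss.item())

Note how the clamp caps the incentive to push ratios far from 1, which is what keeps each policy update proximal.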
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/algorithm/random.html b/docs/build/html/_modules/convlab/agent/algorithm/random.html
new file mode 100644
index 0000000..7de2602
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/algorithm/random.html
@@ -0,0 +1,245 @@

Source code for convlab.agent.algorithm.random

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import numpy as np
+
+# The random agent algorithm
+# For basic dev purpose
+from convlab.agent.algorithm.base import Algorithm
+from convlab.lib import logger, util
+from convlab.lib.decorator import lab_api
+
+logger = logger.get_logger(__name__)
+
+
+
+class Random(Algorithm):
+    '''
+    Example Random agent that works in both discrete and continuous envs
+    '''
+
+    @lab_api
+    def init_algorithm_params(self):
+        '''Initialize other algorithm parameters'''
+        self.to_train = 0
+        self.training_frequency = 1
+        self.training_start_step = 0
+
+    @lab_api
+    def init_nets(self, global_nets=None):
+        '''Initialize the neural network from the spec'''
+        self.net_names = []
+
+    @lab_api
+    def act(self, state):
+        '''Random action'''
+        body = self.body
+        if body.env.is_venv and not util.in_eval_lab_modes():
+            action = np.array([body.action_space.sample() for _ in range(body.env.num_envs)])
+        else:
+            action = body.action_space.sample()
+        return action
+
+    @lab_api
+    def sample(self):
+        self.body.memory.sample()
+        batch = np.nan
+        return batch
+
+    @lab_api
+    def train(self):
+        self.sample()
+        self.body.env.clock.tick('opt_step')  # to simulate metrics calc
+        loss = np.nan
+        return loss
+
+    @lab_api
+    def update(self):
+        self.body.explore_var = np.nan
+        return self.body.explore_var
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/algorithm/reinforce.html b/docs/build/html/_modules/convlab/agent/algorithm/reinforce.html
new file mode 100644
index 0000000..9349450
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/algorithm/reinforce.html
@@ -0,0 +1,398 @@

Source code for convlab.agent.algorithm.reinforce

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import numpy as np
+
+from convlab.agent import net
+from convlab.agent.algorithm import policy_util
+from convlab.agent.algorithm.base import Algorithm
+from convlab.agent.net import net_util
+from convlab.lib import logger, math_util, util
+from convlab.lib.decorator import lab_api
+
+logger = logger.get_logger(__name__)
+
+
+
+class Reinforce(Algorithm):
+    '''
+    Implementation of REINFORCE (Williams, 1992) with baseline for discrete or continuous actions http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf
+    Adapted from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py
+    Algorithm:
+        0. Collect n episodes of data
+        1. At each timestep in an episode
+            - Calculate the advantage of that timestep
+            - Multiply the advantage by the negative of the log probability of the action taken
+        2. Sum all the values above.
+        3. Calculate the gradient of this value with respect to all of the parameters of the network
+        4. Update the network parameters using the gradient
+
+    e.g. algorithm_spec:
+    "algorithm": {
+        "name": "Reinforce",
+        "action_pdtype": "default",
+        "action_policy": "default",
+        "explore_var_spec": null,
+        "gamma": 0.99,
+        "entropy_coef_spec": {
+            "name": "linear_decay",
+            "start_val": 0.01,
+            "end_val": 0.001,
+            "start_step": 100,
+            "end_step": 5000,
+        },
+        "training_frequency": 1,
+    }
+    '''
+
+    @lab_api
+    def init_algorithm_params(self):
+        '''Initialize other algorithm parameters'''
+        # set default
+        util.set_attr(self, dict(
+            action_pdtype='default',
+            action_policy='default',
+            explore_var_spec=None,
+            entropy_coef_spec=None,
+            policy_loss_coef=1.0,
+        ))
+        util.set_attr(self, self.algorithm_spec, [
+            'action_pdtype',
+            'action_policy',
+            # theoretically, REINFORCE does not need an exploration variable update; this implementation merely keeps the option
+            'explore_var_spec',
+            'gamma',  # the discount factor
+            'entropy_coef_spec',
+            'policy_loss_coef',
+            'training_frequency',
+        ])
+        self.to_train = 0
+        self.action_policy = getattr(policy_util, self.action_policy)
+        self.explore_var_scheduler = policy_util.VarScheduler(self.explore_var_spec)
+        self.body.explore_var = self.explore_var_scheduler.start_val
+        if self.entropy_coef_spec is not None:
+            self.entropy_coef_scheduler = policy_util.VarScheduler(self.entropy_coef_spec)
+            self.body.entropy_coef = self.entropy_coef_scheduler.start_val
+
+    @lab_api
+    def init_nets(self, global_nets=None):
+        '''
+        Initialize the neural network used to learn the policy function from the spec.
+        Below we automatically select an appropriate net for a discrete or continuous action space if the setting is of the form 'MLPNet'. Otherwise the correct type of network is assumed to be specified in the spec.
+        Networks for continuous action spaces have two heads and return two values: the first is a tensor containing the mean of the action policy, the second is a tensor containing the std deviation of the action policy. The distribution is assumed to be a Gaussian (Normal) distribution.
+        Networks for discrete action spaces have a single head and return the logits for a categorical probability distribution over the discrete actions.
+        '''
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body)
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dim, out_dim)
+        self.net_names = ['net']
+        # init net optimizer and its lr scheduler
+        self.optim = net_util.get_optim(self.net, self.net.optim_spec)
+        self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
+        self.post_init_nets()
+
+    @lab_api
+    def calc_pdparam(self, x, net=None):
+        '''
+        The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.
+        '''
+        net = self.net if net is None else net
+        pdparam = net(x)
+        return pdparam
+
+    @lab_api
+    def act(self, state):
+        body = self.body
+        action = self.action_policy(state, self, body)
+        return action.cpu().squeeze().numpy()  # squeeze to handle scalar
+
+    @lab_api
+    def sample(self):
+        '''Samples a batch from memory'''
+        batch = self.body.memory.sample()
+        batch = util.to_torch_batch(batch, self.net.device, self.body.memory.is_episodic)
+        return batch
+
+    def calc_pdparam_batch(self, batch):
+        '''Efficiently forward to get pdparam by batch for loss computation'''
+        states = batch['states']
+        if self.body.env.is_venv:
+            states = math_util.venv_unpack(states)
+        pdparam = self.calc_pdparam(states)
+        return pdparam
+
+    def calc_ret_advs(self, batch):
+        '''Calculate plain returns, which are generalized to advantages in ActorCritic'''
+        rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
+        advs = rets
+        if self.body.env.is_venv:
+            advs = math_util.venv_unpack(advs)
+        logger.debug(f'advs: {advs}')
+        return advs
+
+    def calc_policy_loss(self, batch, pdparams, advs):
+        '''Calculate the actor's policy loss'''
+        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
+        actions = batch['actions']
+        if self.body.env.is_venv:
+            actions = math_util.venv_unpack(actions)
+        log_probs = action_pd.log_prob(actions)
+        policy_loss = - self.policy_loss_coef * (log_probs * advs).mean()
+        if self.entropy_coef_spec:
+            entropy = action_pd.entropy().mean()
+            self.body.mean_entropy = entropy  # update logging variable
+            policy_loss += (-self.body.entropy_coef * entropy)
+        logger.debug(f'Actor policy loss: {policy_loss:g}')
+        return policy_loss
+
+    @lab_api
+    def train(self):
+        if util.in_eval_lab_modes():
+            return np.nan
+        clock = self.body.env.clock
+        if self.to_train == 1:
+            batch = self.sample()
+            clock.set_batch_size(len(batch))
+            pdparams = self.calc_pdparam_batch(batch)
+            advs = self.calc_ret_advs(batch)
+            loss = self.calc_policy_loss(batch, pdparams, advs)
+            self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
+            # reset
+            self.to_train = 0
+            logger.info(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
+            return loss.item()
+        else:
+            return np.nan
+
+    @lab_api
+    def update(self):
+        self.body.explore_var = self.explore_var_scheduler.update(self, self.body.env.clock)
+        if self.entropy_coef_spec is not None:
+            self.body.entropy_coef = self.entropy_coef_scheduler.update(self, self.body.env.clock)
+        return self.body.explore_var
+ + +
+class WarmUpReinforce(Reinforce):
+    '''
+    REINFORCE with rule-guided warm-up: identical to Reinforce above, but acts via a rule-guided policy for the first warmup_epi episodes.
+
+    e.g. algorithm_spec:
+    "algorithm": {
+        "name": "WarmUpReinforce",
+        "action_pdtype": "default",
+        "action_policy": "default",
+        "warmup_epi": 300,
+        "explore_var_spec": null,
+        "gamma": 0.99,
+        "entropy_coef_spec": {
+            "name": "linear_decay",
+            "start_val": 0.01,
+            "end_val": 0.001,
+            "start_step": 100,
+            "end_step": 5000,
+        },
+        "training_frequency": 1,
+    }
+    '''
+    def __init__(self, agent, global_nets=None):
+        super().__init__(agent, global_nets)
+        util.set_attr(self, self.algorithm_spec, [
+            'warmup_epi',
+        ])
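
For intuition, the return-weighted policy-gradient loss computed by calc_policy_loss() reduces to the following on toy data (a minimal sketch of the math, not the lab's batch plumbing):

    import torch
    from torch import distributions

    torch.manual_seed(0)
    logits = torch.randn(4, 3, requires_grad=True)  # 4 timesteps, 3 discrete actions
    actions = torch.tensor([0, 1, 2, 1])
    rets = torch.tensor([3.0, 2.0, 1.5, 1.0])  # discounted returns, used as advantages

    action_pd = distributions.Categorical(logits=logits)
    log_probs = action_pd.log_prob(actions)
    policy_loss = -(log_probs * rets).mean()  # minimize negative expected return
    policy_loss.backward()  # gradients ascend the expected return
    print(policy_loss.item())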
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/algorithm/sarsa.html b/docs/build/html/_modules/convlab/agent/algorithm/sarsa.html
new file mode 100644
index 0000000..77eddea
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/algorithm/sarsa.html
@@ -0,0 +1,346 @@

Source code for convlab.agent.algorithm.sarsa

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import numpy as np
+import torch
+
+from convlab.agent import net
+from convlab.agent.algorithm import policy_util
+from convlab.agent.algorithm.base import Algorithm
+from convlab.agent.net import net_util
+from convlab.lib import logger, math_util, util
+from convlab.lib.decorator import lab_api
+
+logger = logger.get_logger(__name__)
+
+
+
+class SARSA(Algorithm):
+    '''
+    Implementation of SARSA.
+
+    Algorithm:
+    Repeat:
+        1. Collect some examples by acting in the environment and store them in an on-policy replay memory (either batch or episodic)
+        2. For each example calculate the target (bootstrapped estimate of the discounted value of the state and action taken), y, using a neural network to approximate the Q function. s_t' is the next state following the action actually taken, a_t. a_t' is the action actually taken in the next state s_t'.
+            y_t = r_t + gamma * Q(s_t', a_t')
+        3. For each example calculate the current estimate of the discounted value of the state and action taken
+            x_t = Q(s_t, a_t)
+        4. Calculate L(x, y) where L is a regression loss (e.g. MSE)
+        5. Calculate the gradient of L with respect to all the parameters in the network and update the network parameters using the gradient
+
+    e.g. algorithm_spec
+    "algorithm": {
+        "name": "SARSA",
+        "action_pdtype": "default",
+        "action_policy": "boltzmann",
+        "explore_var_spec": {
+            "name": "linear_decay",
+            "start_val": 1.0,
+            "end_val": 0.1,
+            "start_step": 10,
+            "end_step": 1000,
+        },
+        "gamma": 0.99,
+        "training_frequency": 10,
+    }
+    '''
+
+    @lab_api
+    def init_algorithm_params(self):
+        '''Initialize other algorithm parameters.'''
+        # set default
+        util.set_attr(self, dict(
+            action_pdtype='default',
+            action_policy='default',
+            explore_var_spec=None,
+        ))
+        util.set_attr(self, self.algorithm_spec, [
+            'action_pdtype',
+            'action_policy',
+            # explore_var is epsilon, tau, etc. depending on the action policy
+            # these control the trade-off between exploration and exploitation
+            'explore_var_spec',
+            'gamma',  # the discount factor
+            'training_frequency',  # how often to train for batch training (once every training_frequency time steps)
+        ])
+        self.to_train = 0
+        self.action_policy = getattr(policy_util, self.action_policy)
+        self.explore_var_scheduler = policy_util.VarScheduler(self.explore_var_spec)
+        self.body.explore_var = self.explore_var_scheduler.start_val
+
+    @lab_api
+    def init_nets(self, global_nets=None):
+        '''Initialize the neural network used to learn the Q function from the spec'''
+        if 'Recurrent' in self.net_spec['type']:
+            self.net_spec.update(seq_len=self.net_spec['seq_len'])
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body)
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dim, out_dim)
+        self.net_names = ['net']
+        # init net optimizer and its lr scheduler
+        self.optim = net_util.get_optim(self.net, self.net.optim_spec)
+        self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
+        self.post_init_nets()
+
+    @lab_api
+    def calc_pdparam(self, x, net=None):
+        '''
+        To get the pdparam for action policy sampling, do a forward pass of the appropriate net, and pick the correct outputs.
+        The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.
+        '''
+        net = self.net if net is None else net
+        pdparam = net(x)
+        return pdparam
+
+    @lab_api
+    def act(self, state):
+        '''Note, SARSA is discrete-only'''
+        body = self.body
+        action = self.action_policy(state, self, body)
+        return action.cpu().squeeze().numpy()  # squeeze to handle scalar
+
+    @lab_api
+    def sample(self):
+        '''Samples a batch from memory'''
+        batch = self.body.memory.sample()
+        # this is safe for next_action at done since the calculated act_next_q_preds will be multiplied by (1 - batch['dones'])
+        batch['next_actions'] = np.zeros_like(batch['actions'])
+        batch['next_actions'][:-1] = batch['actions'][1:]
+        batch = util.to_torch_batch(batch, self.net.device, self.body.memory.is_episodic)
+        return batch
+
+    def calc_q_loss(self, batch):
+        '''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
+        states = batch['states']
+        next_states = batch['next_states']
+        if self.body.env.is_venv:
+            states = math_util.venv_unpack(states)
+            next_states = math_util.venv_unpack(next_states)
+        q_preds = self.net(states)
+        with torch.no_grad():
+            next_q_preds = self.net(next_states)
+        if self.body.env.is_venv:
+            q_preds = math_util.venv_pack(q_preds, self.body.env.num_envs)
+            next_q_preds = math_util.venv_pack(next_q_preds, self.body.env.num_envs)
+        act_q_preds = q_preds.gather(-1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
+        act_next_q_preds = next_q_preds.gather(-1, batch['next_actions'].long().unsqueeze(-1)).squeeze(-1)
+        act_q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * act_next_q_preds
+        logger.debug(f'act_q_preds: {act_q_preds}\nact_q_targets: {act_q_targets}')
+        q_loss = self.net.loss_fn(act_q_preds, act_q_targets)
+        return q_loss
+
+    @lab_api
+    def train(self):
+        '''
+        Completes one training step for the agent if it is time to train.
+        Otherwise this function does nothing.
+        '''
+        if util.in_eval_lab_modes():
+            return np.nan
+        clock = self.body.env.clock
+        if self.to_train == 1:
+            batch = self.sample()
+            clock.set_batch_size(len(batch))
+            loss = self.calc_q_loss(batch)
+            self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
+            # reset
+            self.to_train = 0
+            logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
+            return loss.item()
+        else:
+            return np.nan
+
+    @lab_api
+    def update(self):
+        '''Update the agent after training'''
+        self.body.explore_var = self.explore_var_scheduler.update(self, self.body.env.clock)
+        return self.body.explore_var
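
The next_actions trick in sample() above is just a one-step shift of the action sequence; the padded final entry never matters because the target multiplies the bootstrap by (1 - done). A minimal numpy sketch:

    import numpy as np

    actions = np.array([2, 0, 1, 1, 3])
    dones = np.array([0, 0, 0, 0, 1])

    next_actions = np.zeros_like(actions)
    next_actions[:-1] = actions[1:]  # a_t' is the action actually taken at t+1
    # at the final (done) step next_actions holds a dummy 0, but the target
    # r + gamma * (1 - done) * Q(s', a') zeroes that term anyway
    print(next_actions)  # [0 1 1 3 0]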
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/algorithm/sil.html b/docs/build/html/_modules/convlab/agent/algorithm/sil.html
new file mode 100644
index 0000000..e00be8b
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/algorithm/sil.html
@@ -0,0 +1,386 @@

Source code for convlab.agent.algorithm.sil

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import numpy as np
+import torch
+
+from convlab.agent import memory
+from convlab.agent.algorithm import policy_util
+from convlab.agent.algorithm.actor_critic import ActorCritic
+from convlab.agent.algorithm.ppo import PPO
+from convlab.lib import logger, math_util, util
+from convlab.lib.decorator import lab_api
+
+logger = logger.get_logger(__name__)
+
+
+
+class SIL(ActorCritic):
+    '''
+    Implementation of Self-Imitation Learning (SIL) https://arxiv.org/abs/1806.05635
+    This is actually just A2C with an extra SIL loss function
+
+    e.g. algorithm_spec
+    "algorithm": {
+        "name": "SIL",
+        "action_pdtype": "default",
+        "action_policy": "default",
+        "explore_var_spec": null,
+        "gamma": 0.99,
+        "lam": 1.0,
+        "num_step_returns": 100,
+        "entropy_coef_spec": {
+            "name": "linear_decay",
+            "start_val": 0.01,
+            "end_val": 0.001,
+            "start_step": 100,
+            "end_step": 5000,
+        },
+        "policy_loss_coef": 1.0,
+        "val_loss_coef": 0.01,
+        "sil_policy_loss_coef": 1.0,
+        "sil_val_loss_coef": 0.01,
+        "training_batch_iter": 8,
+        "training_frequency": 1,
+        "training_iter": 8,
+    }
+
+    e.g. special memory_spec
+    "memory": {
+        "name": "OnPolicyReplay",
+        "sil_replay_name": "Replay",
+        "batch_size": 32,
+        "max_size": 10000,
+        "use_cer": true
+    }
+    '''
+
+    def __init__(self, agent, global_nets=None):
+        super().__init__(agent, global_nets)
+        # create the extra replay memory for SIL
+        MemoryClass = getattr(memory, self.memory_spec['sil_replay_name'])
+        self.body.replay_memory = MemoryClass(self.memory_spec, self.body)
+
+    @lab_api
+    def init_algorithm_params(self):
+        '''Initialize other algorithm parameters'''
+        # set default
+        util.set_attr(self, dict(
+            action_pdtype='default',
+            action_policy='default',
+            explore_var_spec=None,
+            entropy_coef_spec=None,
+            policy_loss_coef=1.0,
+            val_loss_coef=1.0,
+        ))
+        util.set_attr(self, self.algorithm_spec, [
+            'action_pdtype',
+            'action_policy',
+            # theoretically, AC does not need an exploration variable update; this implementation merely keeps the option
+            'explore_var_spec',
+            'gamma',  # the discount factor
+            'lam',
+            'num_step_returns',
+            'entropy_coef_spec',
+            'policy_loss_coef',
+            'val_loss_coef',
+            'sil_policy_loss_coef',
+            'sil_val_loss_coef',
+            'training_frequency',
+            'training_batch_iter',
+            'training_iter',
+        ])
+        super().init_algorithm_params()
+
+    def sample(self):
+        '''Modify the onpolicy sample to also append to replay'''
+        batch = self.body.memory.sample()
+        batch = {k: np.concatenate(v) for k, v in batch.items()}  # concat episodic memory
+        for idx in range(len(batch['dones'])):
+            tuples = [batch[k][idx] for k in self.body.replay_memory.data_keys]
+            self.body.replay_memory.add_experience(*tuples)
+        batch = util.to_torch_batch(batch, self.net.device, self.body.replay_memory.is_episodic)
+        return batch
+
+    def replay_sample(self):
+        '''Samples a batch from replay memory'''
+        batch = self.body.replay_memory.sample()
+        batch = util.to_torch_batch(batch, self.net.device, self.body.replay_memory.is_episodic)
+        return batch
+
+    def calc_sil_policy_val_loss(self, batch, pdparams):
+        '''
+        Calculate the SIL policy losses for actor and critic
+        sil_policy_loss = -log_prob * max(R - v_pred, 0)
+        sil_val_loss = (max(R - v_pred, 0)^2) / 2
+        This is called on a randomly sampled batch from experience replay
+        '''
+        v_preds = self.calc_v(batch['states'], use_cache=False)
+        rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
+        clipped_advs = torch.clamp(rets - v_preds, min=0.0)
+
+        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
+        actions = batch['actions']
+        if self.body.env.is_venv:
+            actions = math_util.venv_unpack(actions)
+        log_probs = action_pd.log_prob(actions)
+
+        sil_policy_loss = - self.sil_policy_loss_coef * (log_probs * clipped_advs).mean()
+        sil_val_loss = self.sil_val_loss_coef * clipped_advs.pow(2).mean() / 2
+        logger.debug(f'SIL actor policy loss: {sil_policy_loss:g}')
+        logger.debug(f'SIL critic value loss: {sil_val_loss:g}')
+        return sil_policy_loss, sil_val_loss
+
+    def train(self):
+        clock = self.body.env.clock
+        if self.to_train == 1:
+            # onpolicy update
+            super_loss = super().train()
+            # offpolicy sil update with random minibatch
+            total_sil_loss = torch.tensor(0.0)
+            for _ in range(self.training_iter):
+                batch = self.replay_sample()
+                for _ in range(self.training_batch_iter):
+                    pdparams, _v_preds = self.calc_pdparam_v(batch)
+                    sil_policy_loss, sil_val_loss = self.calc_sil_policy_val_loss(batch, pdparams)
+                    sil_loss = sil_policy_loss + sil_val_loss
+                    self.net.train_step(sil_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
+                    total_sil_loss += sil_loss
+            sil_loss = total_sil_loss / self.training_iter
+            loss = super_loss + sil_loss
+            logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
+            return loss.item()
+        else:
+            return np.nan
+
+
+class PPOSIL(SIL, PPO):
+    '''
+    SIL extended from PPO. This will call the SIL methods and use PPO as super().
+
+    e.g. algorithm_spec
+    "algorithm": {
+        "name": "PPOSIL",
+        "action_pdtype": "default",
+        "action_policy": "default",
+        "explore_var_spec": null,
+        "gamma": 0.99,
+        "lam": 1.0,
+        "clip_eps_spec": {
+            "name": "linear_decay",
+            "start_val": 0.01,
+            "end_val": 0.001,
+            "start_step": 100,
+            "end_step": 5000,
+        },
+        "entropy_coef_spec": {
+            "name": "linear_decay",
+            "start_val": 0.01,
+            "end_val": 0.001,
+            "start_step": 100,
+            "end_step": 5000,
+        },
+        "sil_policy_loss_coef": 1.0,
+        "sil_val_loss_coef": 0.01,
+        "training_frequency": 1,
+        "training_batch_iter": 8,
+        "training_iter": 8,
+        "training_epoch": 8,
+    }
+
+    e.g. special memory_spec
+    "memory": {
+        "name": "OnPolicyReplay",
+        "sil_replay_name": "Replay",
+        "batch_size": 32,
+        "max_size": 10000,
+        "use_cer": true
+    }
+    '''
+    pass
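
The core of calc_sil_policy_val_loss() is the clipped advantage max(R - V(s), 0), which imitates only actions that turned out better than the critic expected. A toy sketch with hand-made values:

    import torch

    rets = torch.tensor([2.0, 0.5, 1.0])     # discounted returns R
    v_preds = torch.tensor([1.0, 1.0, 1.0])  # critic estimates V(s)
    log_probs = torch.tensor([-0.7, -1.1, -0.9])

    clipped_advs = torch.clamp(rets - v_preds, min=0.0)  # only better-than-expected returns remain
    sil_policy_loss = -(log_probs * clipped_advs).mean()
    sil_val_loss = clipped_advs.pow(2).mean() / 2
    print(clipped_advs, sil_policy_loss.item(), sil_val_loss.item())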
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/memory/base.html b/docs/build/html/_modules/convlab/agent/memory/base.html
new file mode 100644
index 0000000..073a1cb
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/memory/base.html
@@ -0,0 +1,223 @@

Source code for convlab.agent.memory.base

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+from abc import ABC, abstractmethod
+
+from convlab.lib import logger
+
+logger = logger.get_logger(__name__)
+
+
+
+class Memory(ABC):
+    '''Abstract Memory class to define the API methods'''
+
+    def __init__(self, memory_spec, body):
+        '''
+        @param {*} body is the unit that stores its experience in this memory. Each body has a distinct memory.
+        '''
+        self.memory_spec = memory_spec
+        self.body = body
+        # declare what data keys to store
+        self.data_keys = ['states', 'actions', 'rewards', 'next_states', 'dones', 'priorities']
+
+    @abstractmethod
+    def reset(self):
+        '''Method to fully reset the memory storage and related variables'''
+        raise NotImplementedError
+
+    @abstractmethod
+    def update(self, state, action, reward, next_state, done):
+        '''Implement memory update given the full info from the latest timestep. NOTE: guard for np.nan reward and done when individual env resets.'''
+        raise NotImplementedError
+
+    @abstractmethod
+    def sample(self):
+        '''Implement memory sampling mechanism'''
+        raise NotImplementedError
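
Any concrete memory only has to implement these three methods. A deliberately tiny, hypothetical subclass (an illustration of the API, not a ConvLab class) might look like this:

    class ListMemory(Memory):
        '''Toy illustration of the Memory API: append-only lists, whole-buffer sampling'''

        def reset(self):
            self.data = {k: [] for k in self.data_keys}

        def update(self, state, action, reward, next_state, done):
            # pad priorities with None since this toy memory does not use them
            for k, v in zip(self.data_keys, (state, action, reward, next_state, done, None)):
                self.data[k].append(v)

        def sample(self):
            return self.data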
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/memory/onpolicy.html b/docs/build/html/_modules/convlab/agent/memory/onpolicy.html
new file mode 100644
index 0000000..26c5658
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/memory/onpolicy.html
@@ -0,0 +1,333 @@
+convlab.agent.memory.onpolicy — ConvLab 0.1 documentation

Source code for convlab.agent.memory.onpolicy

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+from convlab.agent.memory.base import Memory
+from convlab.lib import logger, util
+from convlab.lib.decorator import lab_api
+
+logger = logger.get_logger(__name__)
+
+
+
[docs]class OnPolicyReplay(Memory): + ''' + Stores agent experiences and returns them in a batch for agent training. + + An experience consists of + - state: representation of a state + - action: action taken + - reward: scalar value + - next state: representation of next state (should be same as state) + - done: 0 / 1 representing if the current state is the last in an episode + + The memory does not have a fixed size. Instead the memory stores data from N episodes, where N is determined by the user. After N episodes, all of the examples are returned to the agent to learn from. + + When the examples are returned to the agent, the memory is cleared to prevent the agent from learning from off-policy experiences. This memory is intended for on-policy algorithms. + + Differences vs. Replay memory: + - Experiences are nested into episodes. In Replay experiences are flat, and episode is not tracked + - The entire memory constitutes a batch. In Replay batches are sampled from memory. + - The memory is cleared automatically when a batch is given to the agent. + + e.g. memory_spec + "memory": { + "name": "OnPolicyReplay" + } + ''' + + def __init__(self, memory_spec, body): + super().__init__(memory_spec, body) + # NOTE for OnPolicy replay, frequency = episode; for other classes below frequency = frames + util.set_attr(self, self.body.agent.agent_spec['algorithm'], ['training_frequency']) + # Don't want total experiences (seen_size) reset when memory is cleared + self.is_episodic = True + self.size = 0 # total experiences stored + self.seen_size = 0 # total experiences seen cumulatively + # declare what data keys to store + self.data_keys = ['states', 'actions', 'rewards', 'next_states', 'dones'] + self.reset()
[docs] @lab_api + def reset(self): + '''Resets the memory. Also used to initialize memory vars''' + for k in self.data_keys: + setattr(self, k, []) + self.cur_epi_data = {k: [] for k in self.data_keys} + self.most_recent = (None,) * len(self.data_keys) + self.size = 0
+ +
[docs] @lab_api + def update(self, state, action, reward, next_state, done): + '''Interface method to update memory''' + self.add_experience(state, action, reward, next_state, done)
+ +
[docs] def add_experience(self, state, action, reward, next_state, done): + '''Interface helper method for update() to add experience to memory''' + self.most_recent = (state, action, reward, next_state, done) + for idx, k in enumerate(self.data_keys): + self.cur_epi_data[k].append(self.most_recent[idx]) + # If episode ended, add to memory and clear cur_epi_data + if util.epi_done(done): + for k in self.data_keys: + getattr(self, k).append(self.cur_epi_data[k]) + self.cur_epi_data = {k: [] for k in self.data_keys} + # If agent has collected the desired number of episodes, it is ready to train + # length is num of epis due to nested structure + # if len(self.states) == self.body.agent.algorithm.training_frequency: + if len(self.states) % self.body.agent.algorithm.training_frequency == 0: + self.body.agent.algorithm.to_train = 1 + # Track memory size and num experiences + self.size += 1 + self.seen_size += 1
+ +
[docs] def get_most_recent_experience(self): + '''Returns the most recent experience''' + return self.most_recent
+ +
[docs] def sample(self): + ''' + Returns all the examples from memory in a single batch. Batch is stored as a dict. + Keys are the names of the different elements of an experience. Values are nested lists of the corresponding sampled elements. Elements are nested into episodes + e.g. + batch = { + 'states' : [[s_epi1], [s_epi2], ...], + 'actions' : [[a_epi1], [a_epi2], ...], + 'rewards' : [[r_epi1], [r_epi2], ...], + 'next_states': [[ns_epi1], [ns_epi2], ...], + 'dones' : [[d_epi1], [d_epi2], ...]} + ''' + batch = {k: getattr(self, k) for k in self.data_keys} + self.reset() + return batch
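A hedged sketch of the nested, episodic layout sample() returns, with toy values (two episodes, the first with two steps and the second with one):

batch = {
    'states': [[[0.0], [0.1]], [[0.5]]],
    'actions': [[0, 1], [1]],
    'rewards': [[1.0, 0.0], [2.0]],
    'next_states': [[[0.1], [0.2]], [[0.6]]],
    'dones': [[0, 1], [1]],
}
# outer length counts episodes; inner length counts steps within an episode
print(len(batch['states']), len(batch['states'][0]))  # 2 episodes, 2 steps in episode 1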
+ + +
[docs]class OnPolicyBatchReplay(OnPolicyReplay): + ''' + Same as OnPolicyReplay Memory with the following difference. + + The memory does not have a fixed size. Instead the memory stores data from N experiences, where N is determined by the user. After N experiences or if an episode has ended, all of the examples are returned to the agent to learn from. + + In contrast, OnPolicyReplay stores entire episodes and stores them in a nested structure. OnPolicyBatchReplay stores experiences in a flat structure. + + e.g. memory_spec + "memory": { + "name": "OnPolicyBatchReplay" + } + * batch_size is training_frequency provided by algorithm_spec + ''' + + def __init__(self, memory_spec, body): + super().__init__(memory_spec, body) + self.is_episodic = False + +
[docs] def add_experience(self, state, action, reward, next_state, done): + '''Interface helper method for update() to add experience to memory''' + self.most_recent = [state, action, reward, next_state, done] + for idx, k in enumerate(self.data_keys): + getattr(self, k).append(self.most_recent[idx]) + # Track memory size and num experiences + self.size += 1 + self.seen_size += 1 + # Decide if agent is to train + if len(self.states) == self.body.agent.algorithm.training_frequency: + self.body.agent.algorithm.to_train = 1
+ +
[docs] def sample(self): + ''' + Returns all the examples from memory in a single batch. Batch is stored as a dict. + Keys are the names of the different elements of an experience. Values are a list of the corresponding sampled elements + e.g. + batch = { + 'states' : states, + 'actions' : actions, + 'rewards' : rewards, + 'next_states': next_states, + 'dones' : dones} + ''' + return super().sample()
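By contrast, a sketch of the flat layout OnPolicyBatchReplay returns for the same toy data, one entry per experience rather than per episode:

batch = {
    'states': [[0.0], [0.1], [0.5]],
    'actions': [0, 1, 1],
    'rewards': [1.0, 0.0, 2.0],
    'next_states': [[0.1], [0.2], [0.6]],
    'dones': [0, 1, 1],
}
print(len(batch['states']))  # 3 experiences, no episode nesting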
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/memory/prioritized.html b/docs/build/html/_modules/convlab/agent/memory/prioritized.html
new file mode 100644
index 0000000..6f07446
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/memory/prioritized.html
@@ -0,0 +1,359 @@
+convlab.agent.memory.prioritized — ConvLab 0.1 documentation

Source code for convlab.agent.memory.prioritized

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import random
+
+import numpy as np
+
+from convlab.agent.memory.replay import Replay
+from convlab.lib import util
+
+
+
[docs]class SumTree: + ''' + Helper class for PrioritizedReplay + + This implementation is, with minor adaptations, Jaromír Janisch's. The license is reproduced below. + For more information see his excellent blog series "Let's make a DQN" https://jaromiru.com/2016/09/27/lets-make-a-dqn-theory/ + + MIT License + + Copyright (c) 2018 Jaromír Janisch + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + ''' + write = 0 + + def __init__(self, capacity): + self.capacity = capacity + self.tree = np.zeros(2 * capacity - 1) # Stores the priorities and sums of priorities + self.indices = np.zeros(capacity) # Stores the indices of the experiences + + def _propagate(self, idx, change): + parent = (idx - 1) // 2 + + self.tree[parent] += change + + if parent != 0: + self._propagate(parent, change) + + def _retrieve(self, idx, s): + left = 2 * idx + 1 + right = left + 1 + + if left >= len(self.tree): + return idx + + if s <= self.tree[left]: + return self._retrieve(left, s) + else: + return self._retrieve(right, s - self.tree[left]) + +
[docs] def total(self): + return self.tree[0]
+ +
[docs] def add(self, p, index): + idx = self.write + self.capacity - 1 + + self.indices[self.write] = index + self.update(idx, p) + + self.write += 1 + if self.write >= self.capacity: + self.write = 0
+ +
[docs] def update(self, idx, p): + change = p - self.tree[idx] + + self.tree[idx] = p + self._propagate(idx, change)
+ +
[docs] def get(self, s): + assert s <= self.total() + idx = self._retrieve(0, s) + indexIdx = idx - self.capacity + 1 + + return (idx, self.tree[idx], self.indices[indexIdx])
+ +
[docs] def print_tree(self): + for i in range(len(self.indices)): + j = i + self.capacity - 1 + print(f'Idx: {i}, Data idx: {self.indices[i]}, Prio: {self.tree[j]}')
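The sum-tree supports O(log n) proportional sampling; a hedged, flat-array sketch of the same idea (walking cumulative sums over hypothetical priorities, without the tree itself):

import random

priorities = [1.0, 3.0, 6.0]  # hypothetical priorities for 3 experiences
total = sum(priorities)

def proportional_sample():
    s = random.uniform(0, total)
    cum = 0.0
    for idx, p in enumerate(priorities):
        cum += p
        if s <= cum:
            return idx
    return len(priorities) - 1  # float-edge guard

counts = [0, 0, 0]
for _ in range(10000):
    counts[proportional_sample()] += 1
print(counts)  # expect roughly a 1:3:6 split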
+ + +
[docs]class PrioritizedReplay(Replay): + ''' + Prioritized Experience Replay + + Implementation follows the approach in the paper "Prioritized Experience Replay" (Schaul et al., 2015) https://arxiv.org/pdf/1511.05952.pdf and is Jaromír Janisch's with minor adaptations. + See memory_util.py for the license and link to Jaromír's excellent blog + + Stores agent experiences and samples from them for agent training according to each experience's priority + + The memory has the same behaviour and storage structure as Replay memory with the addition of a SumTree to store and sample the priorities. + + e.g. memory_spec + "memory": { + "name": "PrioritizedReplay", + "alpha": 1, + "epsilon": 0, + "batch_size": 32, + "max_size": 10000, + "use_cer": true + } + ''' + + def __init__(self, memory_spec, body): + util.set_attr(self, memory_spec, [ + 'alpha', + 'epsilon', + 'batch_size', + 'max_size', + 'use_cer', + ]) + super().__init__(memory_spec, body) + + self.epsilon = np.full((1,), self.epsilon) + self.alpha = np.full((1,), self.alpha) + # adds a 'priorities' scalar to the data_keys and calls reset again + self.data_keys = ['states', 'actions', 'rewards', 'next_states', 'dones', 'priorities'] + self.reset()
[docs] def reset(self): + super().reset() + self.tree = SumTree(self.max_size)
+ +
[docs] def add_experience(self, state, action, reward, next_state, done, error=100000): + ''' + Implementation for update() to add experience to memory, expanding the memory size if necessary. + All experiences are added with a high priority to increase the likelihood that they are sampled at least once. + ''' + super().add_experience(state, action, reward, next_state, done) + priority = self.get_priority(error) + self.priorities[self.head] = priority + self.tree.add(priority, self.head)
+ +
[docs] def get_priority(self, error): + '''Takes in the error of one or more examples and returns the proportional priority''' + return np.power(error + self.epsilon, self.alpha).squeeze()
+ +
[docs] def sample_idxs(self, batch_size): + '''Samples batch_size indices from memory in proportion to their priority.''' + batch_idxs = np.zeros(batch_size) + tree_idxs = np.zeros(batch_size, dtype=int) # plain int: np.int is a deprecated alias removed in NumPy 1.24 + + for i in range(batch_size): + s = random.uniform(0, self.tree.total()) + (tree_idx, p, idx) = self.tree.get(s) + batch_idxs[i] = idx + tree_idxs[i] = tree_idx + + batch_idxs = np.asarray(batch_idxs).astype(int) + self.tree_idxs = tree_idxs + if self.use_cer: # add the latest sample + batch_idxs[-1] = self.head + return batch_idxs
+ +
[docs] def update_priorities(self, errors): + ''' + Updates the priorities from the most recent batch + Assumes the relevant batch indices are stored in self.batch_idxs + ''' + priorities = self.get_priority(errors) + assert len(priorities) == self.batch_idxs.size + for idx, p in zip(self.batch_idxs, priorities): + self.priorities[idx] = p + for p, i in zip(priorities, self.tree_idxs): + self.tree.update(i, p)
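A sketch of the get_priority() arithmetic above with hypothetical values; the errors are assumed to be non-negative TD-error magnitudes, and a small epsilon keeps zero-error experiences sampleable:

import numpy as np

errors = np.array([0.5, 2.0, 0.0])
epsilon, alpha = 0.01, 0.6
priorities = np.power(errors + epsilon, alpha)  # p = (error + eps) ** alpha
print(priorities)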
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/memory/replay.html b/docs/build/html/_modules/convlab/agent/memory/replay.html
new file mode 100644
index 0000000..a4e89a0
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/memory/replay.html
@@ -0,0 +1,345 @@
+convlab.agent.memory.replay — ConvLab 0.1 documentation

Source code for convlab.agent.memory.replay

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import numpy as np
+
+from convlab.agent.memory.base import Memory
+from convlab.lib import logger, util
+from convlab.lib.decorator import lab_api
+
+logger = logger.get_logger(__name__)
+
+
+
[docs]def sample_next_states(head, max_size, ns_idx_offset, batch_idxs, states, ns_buffer): + '''Method to sample next_states from states, with proper guard for next_state idx being out of bound''' + # idxs for next state is state idxs with offset, modded + ns_batch_idxs = (batch_idxs + ns_idx_offset) % max_size + # if head < ns_idx <= head + ns_idx_offset, ns is stored in ns_buffer + ns_batch_idxs = ns_batch_idxs % max_size + buffer_ns_locs = np.argwhere( + (head < ns_batch_idxs) & (ns_batch_idxs <= head + ns_idx_offset)).flatten() + # find if there is any idxs to get from buffer + to_replace = buffer_ns_locs.size != 0 + if to_replace: + # extract the buffer_idxs first for replacement later + # given head < ns_idx <= head + offset, and valid buffer idx is [0, offset) + # get 0 < ns_idx - head <= offset, or equiv. + # get -1 < ns_idx - head - 1 <= offset - 1, i.e. + # get 0 <= ns_idx - head - 1 < offset, hence: + buffer_idxs = ns_batch_idxs[buffer_ns_locs] - head - 1 + # set them to 0 first to allow sampling, then replace later with buffer + ns_batch_idxs[buffer_ns_locs] = 0 + # guard all against overrun idxs from offset + ns_batch_idxs = ns_batch_idxs % max_size + next_states = util.batch_get(states, ns_batch_idxs) + if to_replace: + # now replace using buffer_idxs and ns_buffer + buffer_ns = util.batch_get(ns_buffer, buffer_idxs) + next_states[buffer_ns_locs] = buffer_ns + return next_states
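A small sketch of the wrap-around index arithmetic sample_next_states() relies on: next-state indices are the state indices shifted by ns_idx_offset, taken modulo max_size so they stay inside the circular buffer.

import numpy as np

max_size, ns_idx_offset = 5, 1
batch_idxs = np.array([0, 3, 4])
print((batch_idxs + ns_idx_offset) % max_size)  # [1 4 0]: index 4 wraps around to 0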
+ + +
[docs]class Replay(Memory): + ''' + Stores agent experiences and samples from them for agent training + + An experience consists of + - state: representation of a state + - action: action taken + - reward: scalar value + - next state: representation of next state (should be same as state) + - done: 0 / 1 representing if the current state is the last in an episode + + The memory has a size of N. When capacity is reached, the oldest experience + is deleted to make space for the latest experience. + - This is implemented as a circular buffer so that inserting an experience is O(1) + - Each element of an experience is stored as a separate array of size N * element dim + + When a batch of experiences is requested, K experiences are sampled according to a random uniform distribution. + + If 'use_cer', sampling will add the latest experience. + + e.g. memory_spec + "memory": { + "name": "Replay", + "batch_size": 32, + "max_size": 10000, + "use_cer": true + } + ''' + + def __init__(self, memory_spec, body): + super().__init__(memory_spec, body) + util.set_attr(self, self.memory_spec, [ + 'batch_size', + 'max_size', + 'use_cer', + ]) + self.is_episodic = False + self.batch_idxs = None + self.size = 0 # total experiences stored + self.seen_size = 0 # total experiences seen cumulatively + self.head = -1 # index of most recent experience + # generic next_state buffer to store last next_states (allow for multiple for venv) + # self.ns_idx_offset = self.body.env.num_envs if body.env.is_venv else 1 + # self.ns_buffer = deque(maxlen=self.ns_idx_offset) + # declare what data keys to store + self.data_keys = ['states', 'actions', 'rewards', 'next_states', 'dones'] + self.reset()
[docs] def reset(self): + '''Initializes the memory arrays, size and head pointer''' + # set self.states, self.actions, ... + for k in self.data_keys: + setattr(self, k, [None] * self.max_size) + # if k != 'next_states': # reuse self.states + # # list add/sample is over 10x faster than np, also simpler to handle + # setattr(self, k, [None] * self.max_size) + self.size = 0 + self.head = -1
+ # self.ns_buffer.clear() + +
[docs] @lab_api + def update(self, state, action, reward, next_state, done): + '''Interface method to update memory''' + if self.body.env.is_venv: + for sarsd in zip(state, action, reward, next_state, done): + self.add_experience(*sarsd) + else: + self.add_experience(state, action, reward, next_state, done)
+ +
[docs] def add_experience(self, state, action, reward, next_state, done): + '''Implementation for update() to add experience to memory, expanding the memory size if necessary''' + # Move head pointer. Wrap around if necessary + self.head = (self.head + 1) % self.max_size + self.states[self.head] = state.astype(np.float16) + self.actions[self.head] = action + self.rewards[self.head] = reward + self.next_states[self.head] = next_state + # self.ns_buffer.append(next_state.astype(np.float16)) + self.dones[self.head] = done + + # Actually occupied size of memory + if self.size < self.max_size: + self.size += 1 + self.seen_size += 1 + # set to_train using memory counters head, seen_size instead of tick since clock will step by num_envs when on venv; to_train will be set to 0 after training step + algorithm = self.body.agent.algorithm + algorithm.to_train = algorithm.to_train or (self.seen_size > algorithm.training_start_step and self.head % algorithm.training_frequency == 0)
+ +
[docs] @lab_api + def sample(self): + ''' + Returns a batch of batch_size samples. Batch is stored as a dict. + Keys are the names of the different elements of an experience. Values are an array of the corresponding sampled elements + e.g. + batch = { + 'states' : states, + 'actions' : actions, + 'rewards' : rewards, + 'next_states': next_states, + 'dones' : dones} + ''' + self.batch_idxs = self.sample_idxs(self.batch_size) + batch = {} + for k in self.data_keys: + batch[k] = util.batch_get(getattr(self, k), self.batch_idxs) + # if k == 'next_states': + # batch[k] = sample_next_states(self.head, self.max_size, self.ns_idx_offset, self.batch_idxs, self.states, self.ns_buffer) + # else: + # batch[k] = util.batch_get(getattr(self, k), self.batch_idxs) + return batch
+ +
[docs] def sample_idxs(self, batch_size): + '''Batch indices are sampled uniformly at random''' + batch_idxs = np.random.randint(self.size, size=batch_size) + if self.use_cer: # add the latest sample + batch_idxs[-1] = self.head + return batch_idxs
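A sketch of uniform sampling with combined experience replay (use_cer), with hypothetical sizes: the most recent experience, at head, always replaces the last sampled index.

import numpy as np

size, head, batch_size = 100, 57, 8
batch_idxs = np.random.randint(size, size=batch_size)  # uniform over occupied slots
batch_idxs[-1] = head                                  # CER: force-include latest experience
print(batch_idxs)  # last entry is always 57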
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/net/base.html b/docs/build/html/_modules/convlab/agent/net/base.html
new file mode 100644
index 0000000..ce3818b
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/net/base.html
@@ -0,0 +1,239 @@
+convlab.agent.net.base — ConvLab 0.1 documentation

Source code for convlab.agent.net.base

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+from abc import ABC
+
+import pydash as ps
+import torch
+import torch.nn as nn
+
+from convlab.agent.net import net_util
+
+
+
[docs]class Net(ABC): + '''Abstract Net class to define the API methods''' + + def __init__(self, net_spec, in_dim, out_dim): + ''' + @param {dict} net_spec is the spec for the net + @param {int|list} in_dim is the input dimension(s) for the network. Usually use in_dim=body.state_dim + @param {int|list} out_dim is the output dimension(s) for the network. Usually use out_dim=body.action_dim + ''' + self.net_spec = net_spec + self.in_dim = in_dim + self.out_dim = out_dim + self.grad_norms = None # for debugging + if self.net_spec.get('gpu'): + if torch.cuda.device_count(): + self.device = f'cuda:{net_spec.get("cuda_id", 0)}' + else: + self.device = 'cpu' + else: + self.device = 'cpu' + +
[docs] @net_util.dev_check_train_step + def train_step(self, loss, optim, lr_scheduler, clock=None, global_net=None): + lr_scheduler.step(epoch=ps.get(clock, 'frame')) + optim.zero_grad() + loss.backward() + if self.clip_grad_val is not None: + nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val) + if global_net is not None: + net_util.push_global_grads(self, global_net) + optim.step() + if global_net is not None: + net_util.copy(global_net, self) + clock.tick('opt_step') + return loss
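A minimal standalone sketch of the train_step() ordering above (zero grads, backprop, clip, step), without the lab clock, scheduler, or global net:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
optim = torch.optim.SGD(net.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 2)

loss = nn.functional.mse_loss(net(x), y)
optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(net.parameters(), 1.0)  # the clip_grad_val guard
optim.step()
print(loss.item())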
+ +
[docs] def store_grad_norms(self): + '''Stores the gradient norms for debugging.''' + norms = [param.grad.norm().item() for param in self.parameters()] + self.grad_norms = norms
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/net/conv.html b/docs/build/html/_modules/convlab/agent/net/conv.html
new file mode 100644
index 0000000..8453a2e
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/net/conv.html
@@ -0,0 +1,498 @@
+convlab.agent.net.conv — ConvLab 0.1 documentation

Source code for convlab.agent.net.conv

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import pydash as ps
+import torch
+import torch.nn as nn
+
+from convlab.agent.net import net_util
+from convlab.agent.net.base import Net
+from convlab.lib import math_util, util
+
+
+
[docs]class ConvNet(Net, nn.Module): + ''' + Class for generating arbitrary sized convolutional neural network, + with optional batch normalization + + Assumes that a single input example is organized into a 3D tensor. + The entire model consists of three parts: + 1. self.conv_model + 2. self.fc_model + 3. self.model_tails + + e.g. net_spec + "net": { + "type": "ConvNet", + "shared": true, + "conv_hid_layers": [ + [32, 8, 4, 0, 1], + [64, 4, 2, 0, 1], + [64, 3, 1, 0, 1] + ], + "fc_hid_layers": [512], + "hid_layers_activation": "relu", + "out_layer_activation": "tanh", + "init_fn": null, + "normalize": false, + "batch_norm": false, + "clip_grad_val": 1.0, + "loss_spec": { + "name": "SmoothL1Loss" + }, + "optim_spec": { + "name": "Adam", + "lr": 0.02 + }, + "lr_scheduler_spec": { + "name": "StepLR", + "step_size": 30, + "gamma": 0.1 + }, + "update_type": "replace", + "update_frequency": 10000, + "polyak_coef": 0.9, + "gpu": true + } + ''' + + def __init__(self, net_spec, in_dim, out_dim): + ''' + net_spec: + conv_hid_layers: list containing dimensions of the convolutional hidden layers, each is a list representing hid_layer = out_d, kernel, stride, padding, dilation. + Assumed to all come before the flat layers. + Note: a convolutional layer should specify the in_channels, out_channels, kernel_size, stride (of kernel steps), padding, and dilation (spacing between kernel points) E.g. [3, 16, (5, 5), 1, 0, (2, 2)] + For more details, see http://pytorch.org/docs/master/nn.html#conv2d and https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + fc_hid_layers: list of fc layers following the convolutional layers + hid_layers_activation: activation function for the hidden layers + out_layer_activation: activation function for the output layer, same shape as out_dim + init_fn: weight initialization function + normalize: whether to divide by 255.0 to normalize image input + batch_norm: whether to add batch normalization after each convolutional layer, excluding the input layer. + clip_grad_val: clip gradient norm if value is not None + loss_spec: measure of error between model predictions and correct outputs + optim_spec: parameters for initializing the optimizer + lr_scheduler_spec: Pytorch optim.lr_scheduler + update_type: method to update network weights: 'replace' or 'polyak' + update_frequency: how many total timesteps per update + polyak_coef: ratio of polyak weight update + gpu: whether to train using a GPU. Note this will only work if a GPU is available, otherwise setting gpu=True does nothing + ''' + assert len(in_dim) == 3 # image shape (c,w,h) + nn.Module.__init__(self) + super().__init__(net_spec, in_dim, out_dim) + # set default + util.set_attr(self, dict( + out_layer_activation=None, + init_fn=None, + normalize=False, + batch_norm=True, + clip_grad_val=None, + loss_spec={'name': 'MSELoss'}, + optim_spec={'name': 'Adam'}, + lr_scheduler_spec=None, + update_type='replace', + update_frequency=1, + polyak_coef=0.0, + gpu=False, + )) + util.set_attr(self, self.net_spec, [ + 'conv_hid_layers', + 'fc_hid_layers', + 'hid_layers_activation', + 'out_layer_activation', + 'init_fn', + 'normalize', + 'batch_norm', + 'clip_grad_val', + 'loss_spec', + 'optim_spec', + 'lr_scheduler_spec', + 'update_type', + 'update_frequency', + 'polyak_coef', + 'gpu', + ]) + + # conv body + self.conv_model = self.build_conv_layers(self.conv_hid_layers) + self.conv_out_dim = self.get_conv_output_size() + + # fc body + if ps.is_empty(self.fc_hid_layers): + tail_in_dim = self.conv_out_dim + else: + # fc body from flattened conv + self.fc_model = net_util.build_fc_model([self.conv_out_dim] + self.fc_hid_layers, self.hid_layers_activation) + tail_in_dim = self.fc_hid_layers[-1] + + # tails. avoid list for single-tail for compute speed + if ps.is_integer(self.out_dim): + self.model_tail = net_util.build_fc_model([tail_in_dim, self.out_dim], self.out_layer_activation) + else: + if not ps.is_list(self.out_layer_activation): + self.out_layer_activation = [self.out_layer_activation] * len(out_dim) + assert len(self.out_layer_activation) == len(self.out_dim) + tails = [] + for out_d, out_activ in zip(self.out_dim, self.out_layer_activation): + tail = net_util.build_fc_model([tail_in_dim, out_d], out_activ) + tails.append(tail) + self.model_tails = nn.ModuleList(tails) + + net_util.init_layers(self, self.init_fn) + self.loss_fn = net_util.get_loss_fn(self, self.loss_spec) + self.to(self.device) + self.train()
[docs] def get_conv_output_size(self): + '''Helper function to calculate the size of the flattened features after the final convolutional layer''' + with torch.no_grad(): + x = torch.ones(1, *self.in_dim) + x = self.conv_model(x) + return x.numel()
+ +
[docs] def build_conv_layers(self, conv_hid_layers): + ''' + Builds all of the convolutional layers in the network and stores them in a Sequential model + ''' + conv_layers = [] + in_d = self.in_dim[0] # input channel + for i, hid_layer in enumerate(conv_hid_layers): + hid_layer = [tuple(e) if ps.is_list(e) else e for e in hid_layer] # guard list-to-tuple + # hid_layer = out_d, kernel, stride, padding, dilation + conv_layers.append(nn.Conv2d(in_d, *hid_layer)) + if self.hid_layers_activation is not None: + conv_layers.append(net_util.get_activation_fn(self.hid_layers_activation)) + # Don't include batch norm in the first layer + if self.batch_norm and i != 0: + conv_layers.append(nn.BatchNorm2d(hid_layer[0])) # channels must match this conv's out_d + in_d = hid_layer[0] # update to out_d + conv_model = nn.Sequential(*conv_layers) + return conv_model
+ +
[docs] def forward(self, x): + ''' + The feedforward step + Note that PyTorch takes (c,h,w) but gym provides (h,w,c), so preprocessing must be done before passing to network + ''' + if self.normalize: + x = x / 255.0 + x = self.conv_model(x) + x = x.view(x.size(0), -1) # to (batch_size, -1) + if hasattr(self, 'fc_model'): + x = self.fc_model(x) + # return tensor if single tail, else list of tail tensors + if hasattr(self, 'model_tails'): + outs = [] + for model_tail in self.model_tails: + outs.append(model_tail(x)) + return outs + else: + return self.model_tail(x)
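A sketch of the conv stack from the example net_spec above, applied to a hypothetical (4, 84, 84) input (4 stacked frames is an assumption, not part of the spec), reproducing the flattened feature size that get_conv_output_size() would report:

import torch
import torch.nn as nn

conv = nn.Sequential(
    nn.Conv2d(4, 32, 8, 4, 0, 1), nn.ReLU(),
    nn.Conv2d(32, 64, 4, 2, 0, 1), nn.ReLU(),
    nn.Conv2d(64, 64, 3, 1, 0, 1), nn.ReLU(),
)
with torch.no_grad():
    x = torch.ones(1, 4, 84, 84)
    print(conv(x).numel())  # 3136 = 64 * 7 * 7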
+ + +
[docs]class DuelingConvNet(ConvNet): + ''' + Class for generating arbitrary sized convolutional neural network, + with optional batch normalization, and with dueling heads. Intended for Q-Learning algorithms only. + Implementation based on "Dueling Network Architectures for Deep Reinforcement Learning" http://proceedings.mlr.press/v48/wangf16.pdf + + Assumes that a single input example is organized into a 3D tensor. + The entire model consists of three parts: + 1. self.conv_model + 2. self.fc_model + 3. self.model_tails + + e.g. net_spec + "net": { + "type": "DuelingConvNet", + "shared": true, + "conv_hid_layers": [ + [32, 8, 4, 0, 1], + [64, 4, 2, 0, 1], + [64, 3, 1, 0, 1] + ], + "fc_hid_layers": [512], + "hid_layers_activation": "relu", + "init_fn": "xavier_uniform_", + "normalize": false, + "batch_norm": false, + "clip_grad_val": 1.0, + "loss_spec": { + "name": "SmoothL1Loss" + }, + "optim_spec": { + "name": "Adam", + "lr": 0.02 + }, + "lr_scheduler_spec": { + "name": "StepLR", + "step_size": 30, + "gamma": 0.1 + }, + "update_type": "replace", + "update_frequency": 10000, + "polyak_coef": 0.9, + "gpu": true + } + ''' + + def __init__(self, net_spec, in_dim, out_dim): + assert len(in_dim) == 3 # image shape (c,w,h) + nn.Module.__init__(self) + Net.__init__(self, net_spec, in_dim, out_dim) + # set default + util.set_attr(self, dict( + init_fn=None, + normalize=False, + batch_norm=False, + clip_grad_val=None, + loss_spec={'name': 'MSELoss'}, + optim_spec={'name': 'Adam'}, + lr_scheduler_spec=None, + update_type='replace', + update_frequency=1, + polyak_coef=0.0, + gpu=False, + )) + util.set_attr(self, self.net_spec, [ + 'conv_hid_layers', + 'fc_hid_layers', + 'hid_layers_activation', + 'init_fn', + 'normalize', + 'batch_norm', + 'clip_grad_val', + 'loss_spec', + 'optim_spec', + 'lr_scheduler_spec', + 'update_type', + 'update_frequency', + 'polyak_coef', + 'gpu', + ]) + + # Guard against inappropriate algorithms and environments + assert isinstance(out_dim, int) + + # conv body + self.conv_model = self.build_conv_layers(self.conv_hid_layers) + self.conv_out_dim = self.get_conv_output_size() + + # fc body + if ps.is_empty(self.fc_hid_layers): + tail_in_dim = self.conv_out_dim + else: + # fc layer from flattened conv + self.fc_model = net_util.build_fc_model([self.conv_out_dim] + self.fc_hid_layers, self.hid_layers_activation) + tail_in_dim = self.fc_hid_layers[-1] + + # tails. avoid list for single-tail for compute speed + self.v = nn.Linear(tail_in_dim, 1) # state value + self.adv = nn.Linear(tail_in_dim, out_dim) # action dependent raw advantage + self.model_tails = nn.ModuleList([self.v, self.adv]) # nn.ModuleList takes a single iterable of modules + + net_util.init_layers(self, self.init_fn) + self.loss_fn = net_util.get_loss_fn(self, self.loss_spec) + self.to(self.device) + self.train()
[docs] def forward(self, x): + '''The feedforward step''' + if self.normalize: + x = x / 255.0 + x = self.conv_model(x) + x = x.view(x.size(0), -1) # to (batch_size, -1) + if hasattr(self, 'fc_model'): + x = self.fc_model(x) + state_value = self.v(x) + raw_advantages = self.adv(x) + out = math_util.calc_q_value_logits(state_value, raw_advantages) + return out
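A hedged sketch of the dueling aggregation that math_util.calc_q_value_logits is assumed to perform, in the standard form from Wang et al. 2016: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).

import torch

state_value = torch.tensor([[0.5]])                 # V(s), shape (batch, 1)
raw_advantages = torch.tensor([[1.0, -1.0, 0.0]])   # A(s, .), shape (batch, num_actions)
q = state_value + raw_advantages - raw_advantages.mean(dim=-1, keepdim=True)
print(q)  # tensor([[ 1.5000, -0.5000,  0.5000]])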
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/net/mlp.html b/docs/build/html/_modules/convlab/agent/net/mlp.html
new file mode 100644
index 0000000..4879c17
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/net/mlp.html
@@ -0,0 +1,546 @@
+convlab.agent.net.mlp — ConvLab 0.1 documentation

Source code for convlab.agent.net.mlp

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import numpy as np
+import pydash as ps
+import torch
+import torch.nn as nn
+
+from convlab.agent.net import net_util
+from convlab.agent.net.base import Net
+from convlab.lib import math_util, util
+
+
+
[docs]class MLPNet(Net, nn.Module): + ''' + Class for generating arbitrary sized feedforward neural network + If there is more than 1 output tensor, this will create a self.model_tails instead of making the last layer part of self.model + + e.g. net_spec + "net": { + "type": "MLPNet", + "shared": true, + "hid_layers": [32], + "hid_layers_activation": "relu", + "out_layer_activation": null, + "init_fn": "xavier_uniform_", + "clip_grad_val": 1.0, + "loss_spec": { + "name": "MSELoss" + }, + "optim_spec": { + "name": "Adam", + "lr": 0.02 + }, + "lr_scheduler_spec": { + "name": "StepLR", + "step_size": 30, + "gamma": 0.1 + }, + "update_type": "replace", + "update_frequency": 1, + "polyak_coef": 0.9, + "gpu": true + } + ''' + + def __init__(self, net_spec, in_dim, out_dim): + ''' + net_spec: + hid_layers: list containing dimensions of the hidden layers + hid_layers_activation: activation function for the hidden layers + out_layer_activation: activation function for the output layer, same shape as out_dim + init_fn: weight initialization function + clip_grad_val: clip gradient norm if value is not None + loss_spec: measure of error between model predictions and correct outputs + optim_spec: parameters for initializing the optimizer + lr_scheduler_spec: Pytorch optim.lr_scheduler + update_type: method to update network weights: 'replace' or 'polyak' + update_frequency: how many total timesteps per update + polyak_coef: ratio of polyak weight update + gpu: whether to train using a GPU. Note this will only work if a GPU is available, otherwise setting gpu=True does nothing + ''' + nn.Module.__init__(self) + super().__init__(net_spec, in_dim, out_dim) + # set default + util.set_attr(self, dict( + out_layer_activation=None, + init_fn=None, + clip_grad_val=None, + loss_spec={'name': 'MSELoss'}, + optim_spec={'name': 'Adam'}, + lr_scheduler_spec=None, + update_type='replace', + update_frequency=1, + polyak_coef=0.0, + gpu=False, + )) + util.set_attr(self, self.net_spec, [ + 'shared', + 'hid_layers', + 'hid_layers_activation', + 'out_layer_activation', + 'init_fn', + 'clip_grad_val', + 'loss_spec', + 'optim_spec', + 'lr_scheduler_spec', + 'update_type', + 'update_frequency', + 'polyak_coef', + 'gpu', + ]) + + dims = [self.in_dim] + self.hid_layers + self.model = net_util.build_fc_model(dims, self.hid_layers_activation) + # add last layer with no activation + # tails. avoid list for single-tail for compute speed + if ps.is_integer(self.out_dim): + self.model_tail = net_util.build_fc_model([dims[-1], self.out_dim], self.out_layer_activation) + else: + if not ps.is_list(self.out_layer_activation): + self.out_layer_activation = [self.out_layer_activation] * len(out_dim) + assert len(self.out_layer_activation) == len(self.out_dim) + tails = [] + for out_d, out_activ in zip(self.out_dim, self.out_layer_activation): + tail = net_util.build_fc_model([dims[-1], out_d], out_activ) + tails.append(tail) + self.model_tails = nn.ModuleList(tails) + + net_util.init_layers(self, self.init_fn) + self.loss_fn = net_util.get_loss_fn(self, self.loss_spec) + self.to(self.device) + self.train()
[docs] def forward(self, x): + '''The feedforward step''' + x = self.model(x) + if hasattr(self, 'model_tails'): + outs = [] + for model_tail in self.model_tails: + outs.append(model_tail(x)) + return outs + else: + return self.model_tail(x)
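A sketch of the network the example MLPNet spec would build for hypothetical in_dim=4, out_dim=2: one hidden layer of 32 with ReLU, then a single linear tail with no output activation.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 32), nn.ReLU())  # body
model_tail = nn.Linear(32, 2)                       # single tail, no activation
x = torch.randn(1, 4)
print(model_tail(model(x)).shape)  # torch.Size([1, 2])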
+ + +
[docs]class HydraMLPNet(Net, nn.Module): + ''' + Class for generating arbitrary sized feedforward neural network with multiple state and action heads, and a single shared body. + + e.g. net_spec + "net": { + "type": "HydraMLPNet", + "shared": true, + "hid_layers": [ + [[32],[32]], # 2 heads with hidden layers + [64], # body + [] # tail, no hidden layers + ], + "hid_layers_activation": "relu", + "out_layer_activation": null, + "init_fn": "xavier_uniform_", + "clip_grad_val": 1.0, + "loss_spec": { + "name": "MSELoss" + }, + "optim_spec": { + "name": "Adam", + "lr": 0.02 + }, + "lr_scheduler_spec": { + "name": "StepLR", + "step_size": 30, + "gamma": 0.1 + }, + "update_type": "replace", + "update_frequency": 1, + "polyak_coef": 0.9, + "gpu": true + } + ''' + + def __init__(self, net_spec, in_dim, out_dim): + ''' + Multi state processing heads, single shared body, and multi action tails. + There is one state and action head per body/environment + Example: + + env 1 state env 2 state + _______|______ _______|______ + | head 1 | | head 2 | + |______________| |______________| + | | + |__________________| + ________________|_______________ + | Shared body | + |________________________________| + | + ________|_______ + | | + _______|______ ______|_______ + | tail 1 | | tail 2 | + |______________| |______________| + | | + env 1 action env 2 action + ''' + nn.Module.__init__(self) + super().__init__(net_spec, in_dim, out_dim) + # set default + util.set_attr(self, dict( + out_layer_activation=None, + init_fn=None, + clip_grad_val=None, + loss_spec={'name': 'MSELoss'}, + optim_spec={'name': 'Adam'}, + lr_scheduler_spec=None, + update_type='replace', + update_frequency=1, + polyak_coef=0.0, + gpu=False, + )) + util.set_attr(self, self.net_spec, [ + 'hid_layers', + 'hid_layers_activation', + 'out_layer_activation', + 'init_fn', + 'clip_grad_val', + 'loss_spec', + 'optim_spec', + 'lr_scheduler_spec', + 'update_type', + 'update_frequency', + 'polyak_coef', + 'gpu', + ]) + assert len(self.hid_layers) == 3, 'Your hidden layers must specify [*heads], [body], [*tails]. If not, use MLPNet' + assert isinstance(self.in_dim, list), 'Hydra network needs in_dim as list' + assert isinstance(self.out_dim, list), 'Hydra network needs out_dim as list' + self.head_hid_layers = self.hid_layers[0] + self.body_hid_layers = self.hid_layers[1] + self.tail_hid_layers = self.hid_layers[2] + if len(self.head_hid_layers) == 1: + self.head_hid_layers = self.head_hid_layers * len(self.in_dim) + if len(self.tail_hid_layers) == 1: + self.tail_hid_layers = self.tail_hid_layers * len(self.out_dim) + + self.model_heads = self.build_model_heads(in_dim) + heads_out_dim = np.sum([head_hid_layers[-1] for head_hid_layers in self.head_hid_layers]) + dims = [heads_out_dim] + self.body_hid_layers + self.model_body = net_util.build_fc_model(dims, self.hid_layers_activation) + self.model_tails = self.build_model_tails(self.out_dim, self.out_layer_activation) + + net_util.init_layers(self, self.init_fn) + self.loss_fn = net_util.get_loss_fn(self, self.loss_spec) + self.to(self.device) + self.train() + +
[docs] def build_model_heads(self, in_dim): + '''Build each model_head. These are stored as Sequential models in model_heads''' + assert len(self.head_hid_layers) == len(in_dim), 'Hydra head hid_params inconsistent with number in dims' + model_heads = nn.ModuleList() + for in_d, hid_layers in zip(in_dim, self.head_hid_layers): + dims = [in_d] + hid_layers + model_head = net_util.build_fc_model(dims, self.hid_layers_activation) + model_heads.append(model_head) + return model_heads
+ +
[docs] def build_model_tails(self, out_dim, out_layer_activation): + '''Build each model_tail. These are stored as Sequential models in model_tails''' + if not ps.is_list(out_layer_activation): + out_layer_activation = [out_layer_activation] * len(out_dim) + model_tails = nn.ModuleList() + if ps.is_empty(self.tail_hid_layers): + for out_d, out_activ in zip(out_dim, out_layer_activation): + tail = net_util.build_fc_model([self.body_hid_layers[-1], out_d], out_activ) + model_tails.append(tail) + else: + assert len(self.tail_hid_layers) == len(out_dim), 'Hydra tail hid_params inconsistent with number out dims' + for out_d, out_activ, hid_layers in zip(out_dim, out_layer_activation, self.tail_hid_layers): + dims = [self.body_hid_layers[-1]] + hid_layers # connect each tail to the body output dim + model_tail = net_util.build_fc_model(dims, self.hid_layers_activation) + tail_out = net_util.build_fc_model([dims[-1], out_d], out_activ) + model_tail.add_module(str(len(model_tail)), tail_out) + model_tails.append(model_tail) + return model_tails
+ +
[docs] def forward(self, xs): + '''The feedforward step''' + head_xs = [] + for model_head, x in zip(self.model_heads, xs): + head_xs.append(model_head(x)) + head_xs = torch.cat(head_xs, dim=-1) + body_x = self.model_body(head_xs) + outs = [] + for model_tail in self.model_tails: + outs.append(model_tail(body_x)) + return outs
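A sketch of the hydra forward pass with hypothetical dims (in_dim=[3, 5], head width 32 each, body width 64, out_dim=[2, 4]): two state heads, concatenation, one shared body, two tails.

import torch
import torch.nn as nn

heads = nn.ModuleList([nn.Linear(3, 32), nn.Linear(5, 32)])
body = nn.Linear(64, 64)
tails = nn.ModuleList([nn.Linear(64, 2), nn.Linear(64, 4)])

xs = [torch.randn(1, 3), torch.randn(1, 5)]                       # one state per env
head_xs = torch.cat([h(x) for h, x in zip(heads, xs)], dim=-1)    # (1, 64)
outs = [tail(body(head_xs)) for tail in tails]                    # one action head per env
print([o.shape for o in outs])  # [torch.Size([1, 2]), torch.Size([1, 4])]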
+ + +
[docs]class DuelingMLPNet(MLPNet): + ''' + Class for generating arbitrary sized feedforward neural network, with dueling heads. Intended for Q-Learning algorithms only. + Implementation based on "Dueling Network Architectures for Deep Reinforcement Learning" http://proceedings.mlr.press/v48/wangf16.pdf + + e.g. net_spec + "net": { + "type": "DuelingMLPNet", + "shared": true, + "hid_layers": [32], + "hid_layers_activation": "relu", + "init_fn": "xavier_uniform_", + "clip_grad_val": 1.0, + "loss_spec": { + "name": "MSELoss" + }, + "optim_spec": { + "name": "Adam", + "lr": 0.02 + }, + "lr_scheduler_spec": { + "name": "StepLR", + "step_size": 30, + "gamma": 0.1 + }, + "update_type": "replace", + "update_frequency": 1, + "polyak_coef": 0.9, + "gpu": true + } + ''' + + def __init__(self, net_spec, in_dim, out_dim): + nn.Module.__init__(self) + Net.__init__(self, net_spec, in_dim, out_dim) + # set default + util.set_attr(self, dict( + init_fn=None, + clip_grad_val=None, + loss_spec={'name': 'MSELoss'}, + optim_spec={'name': 'Adam'}, + lr_scheduler_spec=None, + update_type='replace', + update_frequency=1, + polyak_coef=0.0, + gpu=False, + )) + util.set_attr(self, self.net_spec, [ + 'shared', + 'hid_layers', + 'hid_layers_activation', + 'init_fn', + 'clip_grad_val', + 'loss_spec', + 'optim_spec', + 'lr_scheduler_spec', + 'update_type', + 'update_frequency', + 'polyak_coef', + 'gpu', + ]) + + # Guard against inappropriate algorithms and environments + # Build model body + dims = [self.in_dim] + self.hid_layers + self.model_body = net_util.build_fc_model(dims, self.hid_layers_activation) + # output layers + self.v = nn.Linear(dims[-1], 1) # state value + self.adv = nn.Linear(dims[-1], out_dim) # action dependent raw advantage + net_util.init_layers(self, self.init_fn) + self.loss_fn = net_util.get_loss_fn(self, self.loss_spec) + self.to(self.device) + +
[docs] def forward(self, x): + '''The feedforward step''' + x = self.model_body(x) + state_value = self.v(x) + raw_advantages = self.adv(x) + out = math_util.calc_q_value_logits(state_value, raw_advantages) + return out
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/net/net_util.html b/docs/build/html/_modules/convlab/agent/net/net_util.html
new file mode 100644
index 0000000..fa33453
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/net/net_util.html
@@ -0,0 +1,545 @@
+convlab.agent.net.net_util — ConvLab 0.1 documentation

Source code for convlab.agent.net.net_util

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+from functools import partial, wraps
+
+import pydash as ps
+import torch
+import torch.nn as nn
+
+from convlab.lib import logger, optimizer, util
+
+logger = logger.get_logger(__name__)
+
+# register custom torch.optim
+setattr(torch.optim, 'GlobalAdam', optimizer.GlobalAdam)
+
+
+
[docs]class NoOpLRScheduler: + '''Symbolic LRScheduler class for API consistency''' + + def __init__(self, optim): + self.optim = optim + +
[docs] def step(self, epoch=None): + pass
+ +
[docs] def get_lr(self): + if hasattr(self.optim, 'defaults'): + return self.optim.defaults['lr'] + else: # TODO retrieve lr more generally + return self.optim.param_groups[0]['lr']
+ + +
[docs]def build_fc_model(dims, activation=None): + '''Build a full-connected model by interleaving nn.Linear and activation_fn''' + assert len(dims) >= 2, 'dims need to at least contain input, output' + # shift dims and make pairs of (in, out) dims per layer + dim_pairs = list(zip(dims[:-1], dims[1:])) + layers = [] + for in_d, out_d in dim_pairs: + layers.append(nn.Linear(in_d, out_d)) + if activation is not None: + layers.append(get_activation_fn(activation)) + model = nn.Sequential(*layers) + return model
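For illustration, what build_fc_model([4, 32, 2], 'relu') is expected to assemble: an activation follows every Linear, including the final pair.

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2), nn.ReLU())
print(model)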
+ + +
[docs]def get_nn_name(uncased_name): + '''Helper to get the proper name in PyTorch nn given a case-insensitive name''' + for nn_name in nn.__dict__: + if uncased_name.lower() == nn_name.lower(): + return nn_name + raise ValueError(f'Name {uncased_name} not found in {nn.__dict__}')
+ + +
[docs]def get_activation_fn(activation): + '''Helper to generate activation function layers for net''' + ActivationClass = getattr(nn, get_nn_name(activation)) + return ActivationClass()
+ + +
[docs]def get_loss_fn(cls, loss_spec): + '''Helper to parse loss param and construct loss_fn for net''' + LossClass = getattr(nn, get_nn_name(loss_spec['name'])) + loss_spec = ps.omit(loss_spec, 'name') + loss_fn = LossClass(**loss_spec) + return loss_fn
+ + +
[docs]def get_lr_scheduler(optim, lr_scheduler_spec): + '''Helper to parse lr_scheduler param and construct Pytorch optim.lr_scheduler''' + if ps.is_empty(lr_scheduler_spec): + lr_scheduler = NoOpLRScheduler(optim) + elif lr_scheduler_spec['name'] == 'LinearToZero': + LRSchedulerClass = getattr(torch.optim.lr_scheduler, 'LambdaLR') + frame = float(lr_scheduler_spec['frame']) + lr_scheduler = LRSchedulerClass(optim, lr_lambda=lambda x: 1 - x / frame) + else: + LRSchedulerClass = getattr(torch.optim.lr_scheduler, lr_scheduler_spec['name']) + lr_scheduler_spec = ps.omit(lr_scheduler_spec, 'name') + lr_scheduler = LRSchedulerClass(optim, **lr_scheduler_spec) + return lr_scheduler
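A sketch of the 'LinearToZero' branch above: a LambdaLR whose multiplier decays the learning rate linearly to zero over `frame` steps (tiny hypothetical frame count so the decay is visible).

import torch
import torch.nn as nn

net = nn.Linear(2, 2)
optim = torch.optim.Adam(net.parameters(), lr=0.01)
frame = 4.0
sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lambda x: 1 - x / frame)
for _ in range(4):
    optim.step()
    sched.step()
    print(sched.get_last_lr())  # [0.0075], [0.005], [0.0025], [0.0]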
+ + +
[docs]def get_optim(net, optim_spec): + '''Helper to parse optim param and construct optim for net''' + OptimClass = getattr(torch.optim, optim_spec['name']) + optim_spec = ps.omit(optim_spec, 'name') + optim = OptimClass(net.parameters(), **optim_spec) + return optim
+ + +
[docs]def get_policy_out_dim(body): + '''Helper method to construct the policy network out_dim for a body according to is_discrete, action_type''' + action_dim = body.action_dim + if body.is_discrete: + if body.action_type == 'multi_discrete': + assert ps.is_list(action_dim), action_dim + policy_out_dim = action_dim + else: + assert ps.is_integer(action_dim), action_dim + policy_out_dim = action_dim + else: + assert ps.is_integer(action_dim), action_dim + if action_dim == 1: # single action, use [loc, scale] + policy_out_dim = 2 + else: # multi-action, use [locs], [scales] + policy_out_dim = [action_dim, action_dim] + return policy_out_dim
+ + +
[docs]def get_out_dim(body, add_critic=False): + '''Construct the NetClass out_dim for a body according to is_discrete, action_type, and whether to add a critic unit''' + policy_out_dim = get_policy_out_dim(body) + if add_critic: + if ps.is_list(policy_out_dim): + out_dim = policy_out_dim + [1] + else: + out_dim = [policy_out_dim, 1] + else: + out_dim = policy_out_dim + return out_dim
+ + +
[docs]def init_layers(net, init_fn_name): + '''Primary method to initialize the weights of the layers of a network''' + if init_fn_name is None: + return + + # get nonlinearity + nonlinearity = get_nn_name(net.hid_layers_activation).lower() + if nonlinearity == 'leakyrelu': + nonlinearity = 'leaky_relu' # guard name + + # get init_fn and add arguments depending on nonlinearity + init_fn = getattr(nn.init, init_fn_name) + if 'kaiming' in init_fn_name: # has 'nonlinearity' as arg + assert nonlinearity in ['relu', 'leaky_relu'], f'Kaiming initialization not supported for {nonlinearity}' + init_fn = partial(init_fn, nonlinearity=nonlinearity) + elif 'orthogonal' in init_fn_name or 'xavier' in init_fn_name: # has 'gain' as arg + gain = nn.init.calculate_gain(nonlinearity) + init_fn = partial(init_fn, gain=gain) + else: + pass + + # finally, apply init_params to each layer in its modules + net.apply(partial(init_params, init_fn=init_fn))
+ + +
[docs]def init_params(module, init_fn): + '''Initialize module's weights using init_fn, and biases to 0.0''' + bias_init = 0.0 + classname = util.get_class_name(module) + if 'Net' in classname: # skip if it's a net, not pytorch layer + pass + elif any(k in classname for k in ('BatchNorm', 'Conv', 'Linear')): + init_fn(module.weight) + nn.init.constant_(module.bias, bias_init) + elif 'GRU' in classname: + for name, param in module.named_parameters(): + if 'weight' in name: + init_fn(param) + elif 'bias' in name: + nn.init.constant_(param, bias_init) + else: + pass
+ + +# params methods + + +
[docs]def save(net, model_path): + '''Save model weights to path''' + torch.save(net.state_dict(), util.smart_path(model_path))
+ + +
[docs]def save_algorithm(algorithm, ckpt=None): + '''Save all the nets for an algorithm''' + agent = algorithm.agent + net_names = algorithm.net_names + model_prepath = agent.spec['meta']['model_prepath'] + if ckpt is not None: + model_prepath = f'{model_prepath}_ckpt-{ckpt}' + for net_name in net_names: + net = getattr(algorithm, net_name) + model_path = f'{model_prepath}_{net_name}_model.pt' + save(net, model_path) + optim_name = net_name.replace('net', 'optim') + optim = getattr(algorithm, optim_name, None) + if optim is not None: # only trainable net has optim + optim_path = f'{model_prepath}_{net_name}_optim.pt' + save(optim, optim_path) + logger.debug(f'Saved algorithm {util.get_class_name(algorithm)} nets {net_names} to {model_prepath}_*.pt')
+ + +
[docs]def load(net, model_path): + '''Load model weights from a path into a net module''' + device = None if torch.cuda.is_available() else 'cpu' + net.load_state_dict(torch.load(util.smart_path(model_path), map_location=device))
+ + +
[docs]def load_algorithm(algorithm): + '''Load all the nets for an algorithm''' + agent = algorithm.agent + net_names = algorithm.net_names + if util.in_eval_lab_modes(): + # load specific model in eval mode + model_prepath = agent.spec['meta']['eval_model_prepath'] + else: + model_prepath = agent.spec['meta']['model_prepath'] + logger.info(f'Loading algorithm {util.get_class_name(algorithm)} nets {net_names} from {model_prepath}_*.pt') + for net_name in net_names: + net = getattr(algorithm, net_name) + model_path = f'{model_prepath}_{net_name}_model.pt' + load(net, model_path) + optim_name = net_name.replace('net', 'optim') + optim = getattr(algorithm, optim_name, None) + if optim is not None: # only trainable net has optim + optim_path = f'{model_prepath}_{net_name}_optim.pt' + load(optim, optim_path)
+ + +
[docs]def copy(src_net, tar_net): + '''Copy model weights from src to target''' + tar_net.load_state_dict(src_net.state_dict())
+ + +
[docs]def polyak_update(src_net, tar_net, old_ratio=0.5): + ''' + Polyak weight update for a target tar_net, mixing in source weights by old_ratio, i.e. + target <- old_ratio * source + (1 - old_ratio) * target + ''' + for src_param, tar_param in zip(src_net.parameters(), tar_net.parameters()): + tar_param.data.copy_(old_ratio * src_param.data + (1.0 - old_ratio) * tar_param.data)
+ + +
[docs]def to_check_train_step(): + '''Condition for running assert_trained''' + return os.environ.get('PY_ENV') == 'test' or util.get_lab_mode() == 'dev'
+ + +
[docs]def dev_check_train_step(fn): + ''' + Decorator to check if net.train_step actually updates the network weights properly + Triggers only if to_check_train_step is True (dev/test mode) + @example + + @net_util.dev_check_train_step + def train_step(self, ...): + ... + ''' + @wraps(fn) + def check_fn(*args, **kwargs): + if not to_check_train_step(): + return fn(*args, **kwargs) + + net = args[0] # first arg self + # get pre-update parameters to compare + pre_params = [param.clone() for param in net.parameters()] + + # run train_step, get loss + loss = fn(*args, **kwargs) + assert not torch.isnan(loss).any(), loss + + # get post-update parameters to compare + post_params = [param.clone() for param in net.parameters()] + if loss == 0.0: + # if loss is 0, there should be no updates + # TODO if without momentum, parameters should not change too + for p_name, param in net.named_parameters(): + assert param.grad.norm() == 0 + else: + # check parameter updates + try: + assert not all(torch.equal(w1, w2) for w1, w2 in zip(pre_params, post_params)), f'Model parameter is not updated in train_step(), check if your tensor is detached from graph. Loss: {loss:g}' + logger.info(f'Model parameter is updated in train_step(). Loss: {loss: g}') + except Exception as e: + logger.error(e) + if os.environ.get('PY_ENV') == 'test': + # raise error if in unit test + raise(e) + + # check grad norms + min_norm, max_norm = 0.0, 1e5 + for p_name, param in net.named_parameters(): + try: + grad_norm = param.grad.norm() + assert min_norm < grad_norm < max_norm, f'Gradient norm for {p_name} is {grad_norm:g}, fails the extreme value check {min_norm} < grad_norm < {max_norm}. Loss: {loss:g}. Check your network and loss computation.' + except Exception as e: + logger.warning(e) + logger.info(f'Gradient norms passed value check.') + logger.debug('Passed network parameter update check.') + # store grad norms for debugging + net.store_grad_norms() + return loss + return check_fn
+ + +
[docs]def get_grad_norms(algorithm): + '''Gather all the net's grad norms of an algorithm for debugging''' + grad_norms = [] + for net_name in algorithm.net_names: + net = getattr(algorithm, net_name) + if net.grad_norms is not None: + grad_norms.extend(net.grad_norms) + return grad_norms
+ + +
[docs]def init_global_nets(algorithm): + ''' + Initialize global_nets for Hogwild using an identical instance of an algorithm from an isolated Session + in spec.meta.distributed, specify either: + - 'shared': global network parameter is shared all the time. In this mode, algorithm local network will be replaced directly by global_net via overriding by identical attribute name + - 'synced': global network parameter is periodically synced to local network after each gradient push. In this mode, algorithm will keep a separate reference to `global_{net}` for each of its networks + ''' + dist_mode = algorithm.agent.spec['meta']['distributed'] + assert dist_mode in ('shared', 'synced'), 'Unrecognized distributed mode' + global_nets = {} + for net_name in algorithm.net_names: + optim_name = net_name.replace('net', 'optim') + if not hasattr(algorithm, optim_name): # only for trainable network, i.e. has an optim + continue + g_net = getattr(algorithm, net_name) + g_net.share_memory() # make net global + if dist_mode == 'shared': # use the same name to override the local net + global_nets[net_name] = g_net + else: # keep a separate reference for syncing + global_nets[f'global_{net_name}'] = g_net + # if optim is Global, set to override the local optim and its scheduler + optim = getattr(algorithm, optim_name) + if 'Global' in util.get_class_name(optim): + optim.share_memory() # make optim global + global_nets[optim_name] = optim + lr_scheduler_name = net_name.replace('net', 'lr_scheduler') + lr_scheduler = getattr(algorithm, lr_scheduler_name) + global_nets[lr_scheduler_name] = lr_scheduler + logger.info(f'Initialized global_nets attr {list(global_nets.keys())} for Hogwild') + return global_nets
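A small sketch of the shared-memory step that makes a net usable across Hogwild worker processes: nn.Module.share_memory() moves parameter storage into shared memory.

import torch.nn as nn

g_net = nn.Linear(4, 2)
g_net.share_memory()  # parameter tensors now live in shared memory
print(all(p.is_shared() for p in g_net.parameters()))  # True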
+ + +
[docs]def set_global_nets(algorithm, global_nets): + '''For Hogwild, set attr built in init_global_nets above. Use in algorithm init.''' + # set attr first so algorithm always has self.global_{net} to pass into train_step + for net_name in algorithm.net_names: + setattr(algorithm, f'global_{net_name}', None) + # set attr created in init_global_nets + if global_nets is not None: + util.set_attr(algorithm, global_nets) + logger.info(f'Set global_nets attr {list(global_nets.keys())} for Hogwild')
+ + +
[docs]def push_global_grads(net, global_net): + '''Push gradients to global_net, call inside train_step between loss.backward() and optim.step()''' + for param, global_param in zip(net.parameters(), global_net.parameters()): + if global_param.grad is not None: + return # quick skip + global_param._grad = param.grad
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/agent/net/recurrent.html b/docs/build/html/_modules/convlab/agent/net/recurrent.html
new file mode 100644
index 0000000..b346045
--- /dev/null
+++ b/docs/build/html/_modules/convlab/agent/net/recurrent.html
@@ -0,0 +1,357 @@
+convlab.agent.net.recurrent — ConvLab 0.1 documentation

Source code for convlab.agent.net.recurrent

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import pydash as ps
+import torch.nn as nn
+
+from convlab.agent.net import net_util
+from convlab.agent.net.base import Net
+from convlab.lib import util
+
+
+
[docs]class RecurrentNet(Net, nn.Module):
    '''
    Class for generating arbitrarily-sized recurrent neural networks which take a sequence of states as input.

    Assumes that a single input example is organized into a 3D tensor
    batch_size x seq_len x state_dim
    The entire model consists of three parts:
    1. self.fc_model (state processing)
    2. self.rnn_model
    3. self.model_tails

    e.g. net_spec
    "net": {
        "type": "RecurrentNet",
        "shared": true,
        "cell_type": "GRU",
        "fc_hid_layers": [],
        "hid_layers_activation": "relu",
        "out_layer_activation": null,
        "rnn_hidden_size": 32,
        "rnn_num_layers": 1,
        "bidirectional": false,
        "seq_len": 4,
        "init_fn": "xavier_uniform_",
        "clip_grad_val": 1.0,
        "loss_spec": {
            "name": "MSELoss"
        },
        "optim_spec": {
            "name": "Adam",
            "lr": 0.01
        },
        "lr_scheduler_spec": {
            "name": "StepLR",
            "step_size": 30,
            "gamma": 0.1
        },
        "update_type": "replace",
        "update_frequency": 1,
        "polyak_coef": 0.9,
        "gpu": true
    }
    '''

    def __init__(self, net_spec, in_dim, out_dim):
        '''
        net_spec:
        cell_type: any of RNN, LSTM, GRU
        fc_hid_layers: list of fc layers preceding the RNN layers
        hid_layers_activation: activation function for the fc hidden layers
        out_layer_activation: activation function for the output layer, same shape as out_dim
        rnn_hidden_size: rnn hidden_size
        rnn_num_layers: number of recurrent layers
        bidirectional: whether the RNN should be bidirectional
        seq_len: length of the state history passed to the net
        init_fn: weight initialization function
        clip_grad_val: clip gradient norm if value is not None
        loss_spec: measure of error between model predictions and correct outputs
        optim_spec: parameters for initializing the optimizer
        lr_scheduler_spec: Pytorch optim.lr_scheduler
        update_type: method to update network weights: 'replace' or 'polyak'
        update_frequency: how many total timesteps per update
        polyak_coef: ratio of polyak weight update
        gpu: whether to train using a GPU. Note this will only work if a GPU is available, otherwise setting gpu=True does nothing
        '''
        nn.Module.__init__(self)
        super().__init__(net_spec, in_dim, out_dim)
        # set default
        util.set_attr(self, dict(
            out_layer_activation=None,
            cell_type='GRU',
            rnn_num_layers=1,
            bidirectional=False,
            init_fn=None,
            clip_grad_val=None,
            loss_spec={'name': 'MSELoss'},
            optim_spec={'name': 'Adam'},
            lr_scheduler_spec=None,
            update_type='replace',
            update_frequency=1,
            polyak_coef=0.0,
            gpu=False,
        ))
        util.set_attr(self, self.net_spec, [
            'cell_type',
            'fc_hid_layers',
            'hid_layers_activation',
            'out_layer_activation',
            'rnn_hidden_size',
            'rnn_num_layers',
            'bidirectional',
            'seq_len',
            'init_fn',
            'clip_grad_val',
            'loss_spec',
            'optim_spec',
            'lr_scheduler_spec',
            'update_type',
            'update_frequency',
            'polyak_coef',
            'gpu',
        ])
        # restore proper in_dim from env stacked state_dim (stack_len, *raw_state_dim)
        self.in_dim = in_dim[1:] if len(in_dim) > 2 else in_dim[1]
        # fc body: state processing model
        if ps.is_empty(self.fc_hid_layers):
            self.rnn_input_dim = self.in_dim
        else:
            fc_dims = [self.in_dim] + self.fc_hid_layers
            self.fc_model = net_util.build_fc_model(fc_dims, self.hid_layers_activation)
            self.rnn_input_dim = fc_dims[-1]

        # RNN model
        self.rnn_model = getattr(nn, net_util.get_nn_name(self.cell_type))(
            input_size=self.rnn_input_dim,
            hidden_size=self.rnn_hidden_size,
            num_layers=self.rnn_num_layers,
            batch_first=True, bidirectional=self.bidirectional)

        # tails. avoid a list for a single tail, for compute speed
        if ps.is_integer(self.out_dim):
            self.model_tail = net_util.build_fc_model([self.rnn_hidden_size, self.out_dim], self.out_layer_activation)
        else:
            if not ps.is_list(self.out_layer_activation):
                self.out_layer_activation = [self.out_layer_activation] * len(self.out_dim)
            assert len(self.out_layer_activation) == len(self.out_dim)
            tails = []
            for out_d, out_activ in zip(self.out_dim, self.out_layer_activation):
                tail = net_util.build_fc_model([self.rnn_hidden_size, out_d], out_activ)
                tails.append(tail)
            self.model_tails = nn.ModuleList(tails)

        net_util.init_layers(self, self.init_fn)
        self.loss_fn = net_util.get_loss_fn(self, self.loss_spec)
        self.to(self.device)
        self.train()
[docs]    def forward(self, x):
        '''The feedforward step. Input is batch_size x seq_len x state_dim'''
        # Unstack input to (batch_size x seq_len) x state_dim in order to transform all state inputs
        batch_size = x.size(0)
        x = x.view(-1, self.in_dim)
        if hasattr(self, 'fc_model'):
            x = self.fc_model(x)
        # Restack to batch_size x seq_len x rnn_input_dim
        x = x.view(-1, self.seq_len, self.rnn_input_dim)
        if self.cell_type == 'LSTM':
            _output, (h_n, c_n) = self.rnn_model(x)
        else:
            _output, h_n = self.rnn_model(x)
        hid_x = h_n[-1]  # get final time-layer
        # return tensor if single tail, else list of tail tensors
        if hasattr(self, 'model_tails'):
            outs = []
            for model_tail in self.model_tails:
                outs.append(model_tail(hid_x))
            return outs
        else:
            return self.model_tail(hid_x)
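A minimal usage sketch, assuming a CPU run and that the base Net class accepts this reduced spec; in_dim follows the lab's stacked-state convention [seq_len, state_dim]:

import torch
from convlab.agent.net.recurrent import RecurrentNet

net_spec = {
    'type': 'RecurrentNet',
    'cell_type': 'GRU',
    'fc_hid_layers': [],
    'hid_layers_activation': 'relu',
    'rnn_hidden_size': 32,
    'seq_len': 4,
}
net = RecurrentNet(net_spec, in_dim=[4, 10], out_dim=5)  # stacked states: seq_len 4, state_dim 10
x = torch.rand(8, 4, 10)  # batch_size x seq_len x state_dim
y = net(x)
assert y.shape == (8, 5)  # single (integer) out_dim -> single tail -> single tensor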
\ No newline at end of file diff --git a/docs/build/html/_modules/convlab/env.html b/docs/build/html/_modules/convlab/env.html new file mode 100644 index 0000000..42862e9 --- /dev/null +++ b/docs/build/html/_modules/convlab/env.html @@ -0,0 +1,271 @@
+convlab.env — ConvLab 0.1 documentation

Source code for convlab.env

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+'''
+The environment module
+Contains graduated components from experiments for building/using environments.
+Provides the rich experience for agent embodiment; reflects the curriculum and allows teaching (possibly allowing a teacher to enter).
+To be designed by humans and the evolution module, based on the curriculum and fitness metrics.
+'''
+import pydash as ps
+
+from convlab.env.base import Clock, ENV_DATA_NAMES
+from convlab.lib import logger, util
+from convlab.lib.decorator import lab_api
+
+logger = logger.get_logger(__name__)
+
+
+
[docs]def make_env(spec, e=None):
    if spec['env'][0]['name'] == 'movie':
        from convlab.env.movie import MovieEnv
        env = MovieEnv(spec, e)
    elif spec['env'][0]['name'] == 'multiwoz':
        from convlab.env.multiwoz import MultiWozEnv
        env = MultiWozEnv(spec, e)
    else:
        raise ValueError(f"Unrecognized env name: {spec['env'][0]['name']}")
    return env
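A hedged sketch of the dispatch contract; the spec below is an illustrative fragment only, and a real spec also needs the meta/env fields that BaseEnv in convlab.env.base reads (e.g. eval_frequency, max_t, max_frame):

spec = {
    'env': [{'name': 'multiwoz'}],  # dispatch key: spec['env'][0]['name']
    'meta': {'distributed': False},
}
env = make_env(spec, e=0)  # -> MultiWozEnv(spec, 0)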
+ + +
[docs]class EnvSpace:
    '''
    Subspace of AEBSpace, collection of all envs, with interface to Session logic; same methods as singleton envs.
    Access AgentSpace properties by: AgentSpace - AEBSpace - EnvSpace - Envs
    '''

    def __init__(self, spec, aeb_space):
        self.spec = spec
        self.aeb_space = aeb_space
        aeb_space.env_space = self
        self.info_space = aeb_space.info_space
        self.envs = []
        for e in range(len(self.spec['env'])):
            env = make_env(self.spec, e)  # make_env above takes no env_space argument
            self.envs.append(env)
        logger.info(util.self_desc(self))
[docs] def get(self, e): + return self.envs[e]
+ +
[docs] def get_base_clock(self): + '''Get the clock with the finest time unit, i.e. ticks the most cycles in a given time, or the highest clock_speed''' + fastest_env = ps.max_by(self.envs, lambda env: env.clock_speed) + clock = fastest_env.clock + return clock
+ +
[docs] @lab_api + def reset(self): + logger.debug3('EnvSpace.reset') + _reward_v, state_v, done_v = self.aeb_space.init_data_v(ENV_DATA_NAMES) + for env in self.envs: + _reward_e, state_e, done_e = env.space_reset() + state_v[env.e, 0:len(state_e)] = state_e + done_v[env.e, 0:len(done_e)] = done_e + _reward_space, state_space, done_space = self.aeb_space.add(ENV_DATA_NAMES, (_reward_v, state_v, done_v)) + logger.debug3(f'\nstate_space: {state_space}') + return _reward_space, state_space, done_space
+ +
[docs] @lab_api + def step(self, action_space): + reward_v, state_v, done_v = self.aeb_space.init_data_v(ENV_DATA_NAMES) + for env in self.envs: + e = env.e + action_e = action_space.get(e=e) + reward_e, state_e, done_e = env.space_step(action_e) + reward_v[e, 0:len(reward_e)] = reward_e + state_v[e, 0:len(state_e)] = state_e + done_v[e, 0:len(done_e)] = done_e + reward_space, state_space, done_space = self.aeb_space.add(ENV_DATA_NAMES, (reward_v, state_v, done_v)) + logger.debug3(f'\nreward_space: {reward_space}\nstate_space: {state_space}\ndone_space: {done_space}') + return reward_space, state_space, done_space
+ +
[docs] @lab_api + def close(self): + logger.info('EnvSpace.close') + for env in self.envs: + env.close()
\ No newline at end of file diff --git a/docs/build/html/_modules/convlab/env/base.html b/docs/build/html/_modules/convlab/env/base.html new file mode 100644 index 0000000..52cd671 --- /dev/null +++ b/docs/build/html/_modules/convlab/env/base.html @@ -0,0 +1,380 @@
+convlab.env.base — ConvLab 0.1 documentation

Source code for convlab.env.base

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import time
+from abc import ABC, abstractmethod
+
+import numpy as np
+import pydash as ps
+from gym import spaces
+
+from convlab.lib import logger, util
+from convlab.lib.decorator import lab_api
+
+ENV_DATA_NAMES = ['reward', 'state', 'done']
+
+logger = logger.get_logger(__name__)
+
+
+
[docs]def set_gym_space_attr(gym_space): + '''Set missing gym space attributes for standardization''' + if isinstance(gym_space, spaces.Box): + setattr(gym_space, 'is_discrete', False) + elif isinstance(gym_space, spaces.Discrete): + setattr(gym_space, 'is_discrete', True) + setattr(gym_space, 'low', 0) + setattr(gym_space, 'high', gym_space.n) + elif isinstance(gym_space, spaces.MultiBinary): + setattr(gym_space, 'is_discrete', True) + setattr(gym_space, 'low', np.full(gym_space.n, 0)) + setattr(gym_space, 'high', np.full(gym_space.n, 2)) + elif isinstance(gym_space, spaces.MultiDiscrete): + setattr(gym_space, 'is_discrete', True) + setattr(gym_space, 'low', np.zeros_like(gym_space.nvec)) + setattr(gym_space, 'high', np.array(gym_space.nvec)) + else: + raise ValueError('gym_space not recognized')
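For example, a standardized Discrete space then exposes the same low/high interface as a Box (high mirrors Discrete.n, so it is an exclusive bound):

from gym import spaces

act_space = spaces.Discrete(4)
set_gym_space_attr(act_space)
print(act_space.is_discrete, act_space.low, act_space.high)  # True 0 4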
+ + +
[docs]class Clock: + '''Clock class for each env and space to keep track of relative time. Ticking and control loop is such that reset is at t=0 and epi=0''' + + def __init__(self, max_frame=int(1e7), clock_speed=1): + self.max_frame = max_frame + self.clock_speed = int(clock_speed) + self.reset() + +
[docs] def reset(self): + self.t = 0 + self.frame = 0 # i.e. total_t + self.epi = 0 + self.start_wall_t = time.time() + self.batch_size = 1 # multiplier to accurately count opt steps + self.opt_step = 0 # count the number of optimizer updates
+ +
[docs] def get(self, unit='frame'): + return getattr(self, unit)
+ +
[docs] def get_elapsed_wall_t(self): + '''Calculate the elapsed wall time (int seconds) since self.start_wall_t''' + return int(time.time() - self.start_wall_t)
+ +
[docs] def set_batch_size(self, batch_size): + self.batch_size = batch_size
+ +
[docs] def tick(self, unit='t'): + if unit == 't': # timestep + self.t += self.clock_speed + self.frame += self.clock_speed + elif unit == 'epi': # episode, reset timestep + self.epi += 1 + self.t = 0 + elif unit == 'opt_step': + self.opt_step += self.batch_size + else: + raise KeyError
+ + +
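A short sketch of the Clock bookkeeping defined above:

clock = Clock(max_frame=1000, clock_speed=1)
clock.tick('t'); clock.tick('t')
print(clock.get('t'), clock.get('frame'))  # 2 2
clock.tick('epi')                          # episode done: epi increments, t resets
print(clock.get('epi'), clock.get('t'))    # 1 0
clock.set_batch_size(32)
clock.tick('opt_step')                     # optimizer updates counted, weighted by batch size
print(clock.get('opt_step'))               # 32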
[docs]class BaseEnv(ABC): + ''' + The base Env class with API and helper methods. Use this to implement your env class that is compatible with the Lab APIs + + e.g. env_spec + "env": [{ + "name": "PongNoFrameskip-v4", + "frame_op": "concat", + "frame_op_len": 4, + "normalize_state": false, + "reward_scale": "sign", + "num_envs": 8, + "max_t": null, + "max_frame": 1e7 + }], + ''' + + def __init__(self, spec, e=None): + self.e = e or 0 # for multi-env + self.done = False + self.env_spec = spec['env'][self.e] + # set default + util.set_attr(self, dict( + log_frequency=None, # default to log at epi done + frame_op=None, + frame_op_len=None, + normalize_state=False, + reward_scale=None, + num_envs=None, + )) + util.set_attr(self, spec['meta'], [ + 'log_frequency', + 'eval_frequency', + ]) + util.set_attr(self, self.env_spec, [ + 'name', + 'frame_op', + 'frame_op_len', + 'normalize_state', + 'reward_scale', + 'num_envs', + 'max_t', + 'max_frame', + ]) + seq_len = ps.get(spec, 'agent.0.net.seq_len') + if seq_len is not None: # infer if using RNN + self.frame_op = 'stack' + self.frame_op_len = seq_len + if util.in_eval_lab_modes(): # use singleton for eval + self.num_envs = 1 + self.log_frequency = None + if spec['meta']['distributed'] != False: # divide max_frame for distributed + self.max_frame = int(self.max_frame / spec['meta']['max_session']) + self.is_venv = (self.num_envs is not None and self.num_envs > 1) + if self.is_venv: + assert self.log_frequency is not None, f'Specify log_frequency when using venv' + self.clock_speed = 1 * (self.num_envs or 1) # tick with a multiple of num_envs to properly count frames + self.clock = Clock(self.max_frame, self.clock_speed) + self.to_render = util.to_render() + + def _set_attr_from_u_env(self, u_env): + '''Set the observation, action dimensions and action type from u_env''' + self.observation_space, self.action_space = self._get_spaces(u_env) + self.observable_dim = self._get_observable_dim(self.observation_space) + self.action_dim = self._get_action_dim(self.action_space) + self.is_discrete = self._is_discrete(self.action_space) + + def _get_spaces(self, u_env): + '''Helper to set the extra attributes to, and get, observation and action spaces''' + observation_space = u_env.observation_space + action_space = u_env.action_space + set_gym_space_attr(observation_space) + set_gym_space_attr(action_space) + return observation_space, action_space + + def _get_observable_dim(self, observation_space): + '''Get the observable dim for an agent in env''' + state_dim = observation_space.shape + if len(state_dim) == 1: + state_dim = state_dim[0] + return {'state': state_dim} + + def _get_action_dim(self, action_space): + '''Get the action dim for an action_space for agent to use''' + if isinstance(action_space, spaces.Box): + assert len(action_space.shape) == 1 + action_dim = action_space.shape[0] + elif isinstance(action_space, (spaces.Discrete, spaces.MultiBinary)): + action_dim = action_space.n + elif isinstance(action_space, spaces.MultiDiscrete): + action_dim = action_space.nvec.tolist() + else: + raise ValueError('action_space not recognized') + return action_dim + + def _is_discrete(self, action_space): + '''Check if an action space is discrete''' + return util.get_class_name(action_space) != 'Box' + +
[docs] @abstractmethod + @lab_api + def reset(self): + '''Reset method, return state''' + raise NotImplementedError
+ +
[docs] @abstractmethod + @lab_api + def step(self, action): + '''Step method, return state, reward, done, info''' + raise NotImplementedError
+ +
[docs] @abstractmethod + @lab_api + def close(self): + '''Method to close and cleanup env''' + raise NotImplementedError
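To add a new environment, subclass BaseEnv and implement the three abstract methods. Below is a minimal sketch with a hypothetical no-op env; constructing it still requires a full spec as in the env_spec example above, and a real env would also set observation/action dims (e.g. via _set_attr_from_u_env):

import numpy as np

class EchoEnv(BaseEnv):
    '''Hypothetical toy env: constant 4-dim zero state, episode ends after max_t steps'''

    @lab_api
    def reset(self):
        self.done = False
        self.t = 0
        return np.zeros(4)

    @lab_api
    def step(self, action):
        self.t += 1
        self.done = self.t >= self.max_t
        return np.zeros(4), 0.0, self.done, {}

    @lab_api
    def close(self):
        pass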
\ No newline at end of file diff --git a/docs/build/html/_modules/convlab/env/movie.html b/docs/build/html/_modules/convlab/env/movie.html new file mode 100644 index 0000000..dcc756d --- /dev/null +++ b/docs/build/html/_modules/convlab/env/movie.html @@ -0,0 +1,1532 @@
+convlab.env.movie — ConvLab 0.1 documentation

Source code for convlab.env.movie

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+import pickle
+import random
+from collections import defaultdict
+from copy import deepcopy
+
+import numpy as np
+import pydash as ps
+from gym import spaces
+
+from convlab.env.base import BaseEnv, ENV_DATA_NAMES, set_gym_space_attr
+# from convlab.env.registration import get_env_path
+from convlab.lib import logger, util
+from convlab.lib.decorator import lab_api
+
+logger = logger.get_logger(__name__)
+
+
+################################################################################
+#   Parameters for Agents
+################################################################################
+agent_params = {}
+agent_params['max_turn'] = 40 
+agent_params['agent_run_mode'] = 1 
+agent_params['agent_act_level'] = 0 
+
+
+################################################################################
+#   Parameters for User Simulators
+################################################################################
+usersim_params = {}
+usersim_params['max_turn'] = 40 
+usersim_params['slot_err_probability'] = 0
+usersim_params['slot_err_mode'] = 0 
+usersim_params['intent_err_probability'] = 0 
+usersim_params['simulator_run_mode'] = 1 
+usersim_params['simulator_act_level'] = 0
+usersim_params['learning_phase'] = 'all' 
+
+DATAPATH = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "data/movie")
+
+dict_path = os.path.join(DATAPATH, 'dicts.v3.p') 
+goal_file_path = os.path.join(DATAPATH, 'user_goals_first_turn_template.part.movie.v1.p')
+
+# load the user goals from .p file
+all_goal_set = pickle.load(open(goal_file_path, 'rb'))
+
+# split goal set
+split_fold = 5
+goal_set = {'train':[], 'valid':[], 'test':[], 'all':[]}
+for u_goal_id, u_goal in enumerate(all_goal_set):
+    if u_goal_id % split_fold == 1: goal_set['test'].append(u_goal)
+    else: goal_set['train'].append(u_goal)
+    goal_set['all'].append(u_goal)
+# end split goal set
+
+movie_kb_path = os.path.join(DATAPATH, 'movie_kb.1k.p')
+# movie_kb = pickle.load(open(movie_kb_path, 'rb'), encoding='latin1')
+movie_dictionary = pickle.load(open(movie_kb_path, 'rb'), encoding='latin1')
+
+
[docs]def text_to_dict(path): + """ Read in a text file as a dictionary where keys are text and values are indices (line numbers) """ + + slot_set = {} + with open(path, 'r') as f: + index = 0 + for line in f.readlines(): + slot_set[line.strip('\n').strip('\r')] = index + index += 1 + return slot_set
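For instance, given a hypothetical dia_acts.txt whose lines are request, inform, thanks:

# text_to_dict('dia_acts.txt') -> {'request': 0, 'inform': 1, 'thanks': 2}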
act_set = text_to_dict(os.path.join(DATAPATH, 'dia_acts.txt'))
slot_set = text_to_dict(os.path.join(DATAPATH, 'slot_set.txt'))

################################################################################
# a movie dictionary for user simulator - slot:possible values
################################################################################
# movie_dictionary = pickle.load(open(dict_path, 'rb'))

sys_request_slots = ['moviename', 'theater', 'starttime', 'date', 'numberofpeople', 'genre', 'state', 'city', 'zip', 'critic_rating', 'mpaa_rating', 'distanceconstraints', 'video_format', 'theater_chain', 'price', 'actor', 'description', 'other', 'numberofkids']
sys_inform_slots = ['moviename', 'theater', 'starttime', 'date', 'genre', 'state', 'city', 'zip', 'critic_rating', 'mpaa_rating', 'distanceconstraints', 'video_format', 'theater_chain', 'price', 'actor', 'description', 'other', 'numberofkids', 'taskcomplete', 'ticket']

start_dia_acts = {
    # 'greeting': [],
    'request': ['moviename', 'starttime', 'theater', 'city', 'state', 'date', 'genre', 'ticket', 'numberofpeople']
}

################################################################################
# Dialog status
################################################################################
FAILED_DIALOG = -1
SUCCESS_DIALOG = 1
NO_OUTCOME_YET = 0

# Rewards
SUCCESS_REWARD = 50
FAILURE_REWARD = 0
PER_TURN_REWARD = 0

################################################################################
# Special Slot Values
################################################################################
I_DO_NOT_CARE = "I do not care"
NO_VALUE_MATCH = "NO VALUE MATCHES!!!"
TICKET_AVAILABLE = 'Ticket Available'

################################################################################
# Constraint Check
################################################################################
CONSTRAINT_CHECK_FAILURE = 0
CONSTRAINT_CHECK_SUCCESS = 1

################################################################################
# NLG Beam Search
################################################################################
nlg_beam_size = 10

################################################################################
# run_mode: 0 for dia-act; 1 for NL; 2 for no output
################################################################################
run_mode = 3
auto_suggest = 0

################################################################################
# A Basic Set of Feasible actions to be Considered By an RL agent
################################################################################
feasible_actions = [
    ############################################################################
    # greeting actions
    ############################################################################
    # {'diaact': "greeting", 'inform_slots': {}, 'request_slots': {}},
    ############################################################################
    # confirm_question actions
    ############################################################################
    {'diaact': "confirm_question", 'inform_slots': {}, 'request_slots': {}},
    ############################################################################
    # confirm_answer actions
    ############################################################################
    {'diaact': "confirm_answer", 'inform_slots': {}, 'request_slots': {}},
    ############################################################################
    # thanks actions
    ############################################################################
    {'diaact': "thanks", 'inform_slots': {}, 'request_slots': {}},
    ############################################################################
    # deny actions
    ############################################################################
    {'diaact': "deny", 'inform_slots': {}, 'request_slots': {}},
]
############################################################################
# Adding the inform actions
############################################################################
for slot in sys_inform_slots:
    feasible_actions.append({'diaact': 'inform', 'inform_slots': {slot: "PLACEHOLDER"}, 'request_slots': {}})

############################################################################
# Adding the request actions
############################################################################
for slot in sys_request_slots:
    feasible_actions.append({'diaact': 'request', 'inform_slots': {}, 'request_slots': {slot: "UNK"}})
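With the slot lists above, the action space is fixed and enumerable: 4 base acts (confirm_question, confirm_answer, thanks, deny) + 20 inform actions (one per sys_inform_slots entry) + 19 request actions (one per sys_request_slots entry) = 43 feasible actions, which is the self.num_actions that MovieActInActOutEnvironment later reports:

assert len(feasible_actions) == 4 + len(sys_inform_slots) + len(sys_request_slots) == 43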
[docs]class UserSimulator: + """ Parent class for all user sims to inherit from """ + + def __init__(self, movie_dict=None, act_set=None, slot_set=None, start_set=None, params=None): + """ Constructor shared by all user simulators """ + + self.movie_dict = movie_dict + self.act_set = act_set + self.slot_set = slot_set + self.start_set = start_set + + self.max_turn = usersim_params['max_turn'] + self.slot_err_probability = usersim_params['slot_err_probability'] + self.slot_err_mode = usersim_params['slot_err_mode'] + self.intent_err_probability = usersim_params['intent_err_probability'] + + +
[docs]    def initialize_episode(self):
        """ Initialize a new episode (dialog) """

        print("initialize episode called, generating goal")
        self.goal = random.choice(self.start_set)
        self.goal['request_slots']['ticket'] = 'UNK'
        episode_over, user_action = self._sample_action()
        assert episode_over != 1, 'the episode cannot be over when it has just started'
        return user_action
+ + +
[docs] def next(self, system_action): + pass
+ + + +
[docs] def set_nlg_model(self, nlg_model): + self.nlg_model = nlg_model
+ +
[docs] def set_nlu_model(self, nlu_model): + self.nlu_model = nlu_model
+ + + +
[docs]    def add_nl_to_action(self, user_action):
        """ Add NL to User Dia_Act """

        user_nlg_sentence = self.nlg_model.convert_diaact_to_nl(user_action, 'usr')
        user_action['nl'] = user_nlg_sentence

        if self.simulator_act_level == 1:
            user_nlu_res = self.nlu_model.generate_dia_act(user_action['nl'])  # NLU
            if user_nlu_res is not None:
                # user_nlu_res['diaact'] = user_action['diaact']  # or not?
                user_action.update(user_nlu_res)
+ + + +
[docs]class RuleSimulator(UserSimulator): + """ A rule-based user simulator for testing dialog policy """ + + def __init__(self, movie_dict=None, act_set=None, slot_set=None, start_set=None, params=None): + """ Constructor shared by all user simulators """ + + self.movie_dict = movie_dict + self.act_set = act_set + self.slot_set = slot_set + self.start_set = start_set + + self.max_turn = usersim_params['max_turn'] + self.slot_err_probability = usersim_params['slot_err_probability'] + self.slot_err_mode = usersim_params['slot_err_mode'] + self.intent_err_probability = usersim_params['intent_err_probability'] + + self.simulator_run_mode = usersim_params['simulator_run_mode'] + self.simulator_act_level = usersim_params['simulator_act_level'] + + self.learning_phase = usersim_params['learning_phase'] + +
[docs]    def initialize_episode(self):
        """ Initialize a new episode (dialog)
        state['history_slots']: keeps all the informed slots
        state['rest_slots']: keeps all the slots that still remain on the stack
        """

        self.state = {}
        self.state['history_slots'] = {}
        self.state['inform_slots'] = {}
        self.state['request_slots'] = {}
        self.state['rest_slots'] = []
        self.state['turn'] = 0

        self.episode_over = False
        self.dialog_status = NO_OUTCOME_YET

        # self.goal = random.choice(self.start_set)
        self.goal = self._sample_goal(self.start_set)
        self.goal['request_slots']['ticket'] = 'UNK'
        self.constraint_check = CONSTRAINT_CHECK_FAILURE

        """ Debug: build a fake goal manually """
        # self.debug_falk_goal()

        # sample first action
        user_action = self._sample_action()
        assert self.episode_over != 1, 'the episode cannot be over when it has just started'
        return user_action
+ + def _sample_action(self): + """ randomly sample a start action based on user goal """ + + self.state['diaact'] = random.choice(list(start_dia_acts.keys())) + + # "sample" informed slots + if len(self.goal['inform_slots']) > 0: + known_slot = random.choice(list(self.goal['inform_slots'].keys())) + self.state['inform_slots'][known_slot] = self.goal['inform_slots'][known_slot] + + if 'moviename' in self.goal['inform_slots']: # 'moviename' must appear in the first user turn + self.state['inform_slots']['moviename'] = self.goal['inform_slots']['moviename'] + + for slot in self.goal['inform_slots'].keys(): + if known_slot == slot or slot == 'moviename': continue + self.state['rest_slots'].append(slot) + + self.state['rest_slots'].extend(self.goal['request_slots'].keys()) + + # "sample" a requested slot + request_slot_set = list(self.goal['request_slots'].keys()) + request_slot_set.remove('ticket') + if len(request_slot_set) > 0: + request_slot = random.choice(request_slot_set) + else: + request_slot = 'ticket' + self.state['request_slots'][request_slot] = 'UNK' + + if len(self.state['request_slots']) == 0: + self.state['diaact'] = 'inform' + + if (self.state['diaact'] in ['thanks','closing']): self.episode_over = True #episode_over = True + else: self.episode_over = False #episode_over = False + + sample_action = {} + sample_action['diaact'] = self.state['diaact'] + sample_action['inform_slots'] = self.state['inform_slots'] + sample_action['request_slots'] = self.state['request_slots'] + sample_action['turn'] = self.state['turn'] + + # self.add_nl_to_action(sample_action) + return sample_action + + def _sample_goal(self, goal_set): + """ sample a user goal """ + + sample_goal = random.choice(self.start_set[self.learning_phase]) + return sample_goal + + +
[docs]    def corrupt(self, user_action):
        """ Randomly corrupt an action with error probs (slot_err_probability and slot_err_mode) on Slot and Intent (intent_err_probability). """

        for slot in list(user_action['inform_slots'].keys()):  # copy keys: slots may be deleted during iteration
            slot_err_prob_sample = random.random()
            if slot_err_prob_sample < self.slot_err_probability:  # add noise for slot level
                if self.slot_err_mode == 0:  # replace the slot_value only
                    if slot in self.movie_dict.keys(): user_action['inform_slots'][slot] = random.choice(self.movie_dict[slot])
                elif self.slot_err_mode == 1:  # combined
                    slot_err_random = random.random()
                    if slot_err_random <= 0.33:
                        if slot in self.movie_dict.keys(): user_action['inform_slots'][slot] = random.choice(self.movie_dict[slot])
                    elif slot_err_random > 0.33 and slot_err_random <= 0.66:
                        del user_action['inform_slots'][slot]
                        random_slot = random.choice(list(self.movie_dict.keys()))  # list() needed: random.choice cannot take a dict view
                        user_action[random_slot] = random.choice(self.movie_dict[random_slot])
                    else:
                        del user_action['inform_slots'][slot]
                elif self.slot_err_mode == 2:  # replace slot and its values
                    del user_action['inform_slots'][slot]
                    random_slot = random.choice(list(self.movie_dict.keys()))
                    user_action[random_slot] = random.choice(self.movie_dict[random_slot])
                elif self.slot_err_mode == 3:  # delete the slot
                    del user_action['inform_slots'][slot]

        intent_err_sample = random.random()
        if intent_err_sample < self.intent_err_probability:  # add noise for intent level
            user_action['diaact'] = random.choice(list(self.act_set.keys()))
+ +
[docs]    def debug_falk_goal(self):
        """ Debug function: build a fake goal manually (can be moved in future) """

        self.goal['inform_slots'].clear()
        # self.goal['inform_slots']['city'] = 'seattle'
        self.goal['inform_slots']['numberofpeople'] = '2'
        # self.goal['inform_slots']['theater'] = 'amc pacific place 11 theater'
        # self.goal['inform_slots']['starttime'] = '10:00 pm'
        # self.goal['inform_slots']['date'] = 'tomorrow'
        self.goal['inform_slots']['moviename'] = 'zoology'
        self.goal['inform_slots']['distanceconstraints'] = 'close to 95833'
        self.goal['request_slots'].clear()
        self.goal['request_slots']['ticket'] = 'UNK'
        self.goal['request_slots']['theater'] = 'UNK'
        self.goal['request_slots']['starttime'] = 'UNK'
        self.goal['request_slots']['date'] = 'UNK'
+ +
[docs] def next(self, system_action): + """ Generate next User Action based on last System Action """ + + self.state['turn'] += 2 + self.episode_over = False + self.dialog_status = NO_OUTCOME_YET + + sys_act = system_action['diaact'] + + if (self.max_turn > 0 and self.state['turn'] > self.max_turn): + self.dialog_status = FAILED_DIALOG + self.episode_over = True + self.state['diaact'] = "closing" + else: + self.state['history_slots'].update(self.state['inform_slots']) + self.state['inform_slots'].clear() + + if sys_act == "inform": + self.response_inform(system_action) + elif sys_act == "multiple_choice": + self.response_multiple_choice(system_action) + elif sys_act == "request": + self.response_request(system_action) + elif sys_act == "thanks": + self.response_thanks(system_action) + elif sys_act == "confirm_answer": + self.response_confirm_answer(system_action) + elif sys_act == "closing": + self.episode_over = True + self.state['diaact'] = "thanks" + + self.corrupt(self.state) + + response_action = {} + response_action['diaact'] = self.state['diaact'] + response_action['inform_slots'] = self.state['inform_slots'] + response_action['request_slots'] = self.state['request_slots'] + response_action['turn'] = self.state['turn'] + response_action['nl'] = "" + + # add NL to dia_act + # self.add_nl_to_action(response_action) + return response_action, self.episode_over, self.dialog_status
+ + +
[docs] def response_confirm_answer(self, system_action): + """ Response for Confirm_Answer (System Action) """ + + if len(self.state['rest_slots']) > 0: + request_slot = random.choice(self.state['rest_slots']) + + if request_slot in self.goal['request_slots'].keys(): + self.state['diaact'] = "request" + self.state['request_slots'][request_slot] = "UNK" + elif request_slot in self.goal['inform_slots'].keys(): + self.state['diaact'] = "inform" + self.state['inform_slots'][request_slot] = self.goal['inform_slots'][request_slot] + if request_slot in self.state['rest_slots']: + self.state['rest_slots'].remove(request_slot) + else: + self.state['diaact'] = "thanks"
+ +
[docs] def response_thanks(self, system_action): + """ Response for Thanks (System Action) """ + + self.episode_over = True + self.dialog_status = SUCCESS_DIALOG + + request_slot_set = deepcopy(list(self.state['request_slots'].keys())) + if 'ticket' in request_slot_set: + request_slot_set.remove('ticket') + rest_slot_set = deepcopy(self.state['rest_slots']) + if 'ticket' in rest_slot_set: + rest_slot_set.remove('ticket') + + if len(request_slot_set) > 0 or len(rest_slot_set) > 0: + self.dialog_status = FAILED_DIALOG + + for info_slot in self.state['history_slots'].keys(): + if self.state['history_slots'][info_slot] == NO_VALUE_MATCH: + self.dialog_status = FAILED_DIALOG + if info_slot in self.goal['inform_slots'].keys(): + if self.state['history_slots'][info_slot] != self.goal['inform_slots'][info_slot]: + self.dialog_status = FAILED_DIALOG + + if 'ticket' in system_action['inform_slots'].keys(): + if system_action['inform_slots']['ticket'] == NO_VALUE_MATCH: + self.dialog_status = FAILED_DIALOG + + if self.constraint_check == CONSTRAINT_CHECK_FAILURE: + self.dialog_status = FAILED_DIALOG
+ +
[docs] def response_request(self, system_action): + """ Response for Request (System Action) """ + + if len(system_action['request_slots'].keys()) > 0: + slot = list(system_action['request_slots'].keys())[0] # only one slot + if slot in self.goal['inform_slots']: # request slot in user's constraints #and slot not in self.state['request_slots'].keys(): + self.state['inform_slots'][slot] = self.goal['inform_slots'][slot] + self.state['diaact'] = "inform" + if slot in self.state['rest_slots']: self.state['rest_slots'].remove(slot) + if slot in self.state['request_slots']: del self.state['request_slots'][slot] + self.state['request_slots'].clear() + elif slot in self.goal['request_slots'] and slot not in self.state['rest_slots'] and slot in self.state['history_slots']: # the requested slot has been answered + self.state['inform_slots'][slot] = self.state['history_slots'][slot] + self.state['request_slots'].clear() + self.state['diaact'] = "inform" + elif slot in self.goal['request_slots'].keys() and slot in self.state['rest_slots']: # request slot in user's goal's request slots, and not answered yet + self.state['diaact'] = "request" # "confirm_question" + self.state['request_slots'][slot] = "UNK" + + ######################################################################## + # Inform the rest of informable slots + ######################################################################## + for info_slot in self.state['rest_slots']: + if info_slot in self.goal['inform_slots'].keys(): + self.state['inform_slots'][info_slot] = self.goal['inform_slots'][info_slot] + + for info_slot in self.state['inform_slots'].keys(): + if info_slot in self.state['rest_slots']: + self.state['rest_slots'].remove(info_slot) + else: + if len(self.state['request_slots']) == 0 and len(self.state['rest_slots']) == 0: + self.state['diaact'] = "thanks" + else: + self.state['diaact'] = "inform" + self.state['inform_slots'][slot] = I_DO_NOT_CARE + else: # this case should not appear + if len(self.state['rest_slots']) > 0: + random_slot = random.choice(self.state['rest_slots']) + if random_slot in self.goal['inform_slots'].keys(): + self.state['inform_slots'][random_slot] = self.goal['inform_slots'][random_slot] + self.state['rest_slots'].remove(random_slot) + self.state['diaact'] = "inform" + elif random_slot in self.goal['request_slots'].keys(): + self.state['request_slots'][random_slot] = self.goal['request_slots'][random_slot] + self.state['diaact'] = "request"
+ +
[docs]    def response_multiple_choice(self, system_action):
        """ Response for Multiple_Choice (System Action) """

        slot = list(system_action['inform_slots'].keys())[0]  # list() needed: dict views do not support indexing
        if slot in self.goal['inform_slots'].keys():
            self.state['inform_slots'][slot] = self.goal['inform_slots'][slot]
        elif slot in self.goal['request_slots'].keys():
            self.state['inform_slots'][slot] = random.choice(system_action['inform_slots'][slot])

        self.state['diaact'] = "inform"
        if slot in self.state['rest_slots']: self.state['rest_slots'].remove(slot)
        if slot in self.state['request_slots'].keys(): del self.state['request_slots'][slot]
+ +
[docs] def response_inform(self, system_action): + """ Response for Inform (System Action) """ + + if 'taskcomplete' in system_action['inform_slots'].keys(): # check all the constraints from agents with user goal + self.state['diaact'] = "thanks" + #if 'ticket' in self.state['rest_slots']: self.state['request_slots']['ticket'] = 'UNK' + self.constraint_check = CONSTRAINT_CHECK_SUCCESS + + if system_action['inform_slots']['taskcomplete'] == NO_VALUE_MATCH: + self.state['history_slots']['ticket'] = NO_VALUE_MATCH + if 'ticket' in self.state['rest_slots']: self.state['rest_slots'].remove('ticket') + if 'ticket' in self.state['request_slots'].keys(): del self.state['request_slots']['ticket'] + + for slot in self.goal['inform_slots'].keys(): + # Deny, if the answers from agent can not meet the constraints of user + if slot not in system_action['inform_slots'].keys() or (self.goal['inform_slots'][slot].lower() != system_action['inform_slots'][slot].lower()): + self.state['diaact'] = "deny" + self.state['request_slots'].clear() + self.state['inform_slots'].clear() + self.constraint_check = CONSTRAINT_CHECK_FAILURE + break + else: + for slot in system_action['inform_slots'].keys(): + self.state['history_slots'][slot] = system_action['inform_slots'][slot] + + if slot in self.goal['inform_slots'].keys(): + if system_action['inform_slots'][slot] == self.goal['inform_slots'][slot]: + if slot in self.state['rest_slots']: self.state['rest_slots'].remove(slot) + + if len(self.state['request_slots']) > 0: + self.state['diaact'] = "request" + elif len(self.state['rest_slots']) > 0: + rest_slot_set = deepcopy(self.state['rest_slots']) + if 'ticket' in rest_slot_set: + rest_slot_set.remove('ticket') + + if len(rest_slot_set) > 0: + inform_slot = random.choice(rest_slot_set) # self.state['rest_slots'] + if inform_slot in self.goal['inform_slots'].keys(): + self.state['inform_slots'][inform_slot] = self.goal['inform_slots'][inform_slot] + self.state['diaact'] = "inform" + self.state['rest_slots'].remove(inform_slot) + elif inform_slot in self.goal['request_slots'].keys(): + self.state['request_slots'][inform_slot] = 'UNK' + self.state['diaact'] = "request" + else: + self.state['request_slots']['ticket'] = 'UNK' + self.state['diaact'] = "request" + else: # how to reply here? + self.state['diaact'] = "thanks" # replies "closing"? or replies "confirm_answer" + else: # != value Should we deny here or ? + ######################################################################## + # TODO When agent informs(slot=value), where the value is different with the constraint in user goal, Should we deny or just inform the correct value? 
+ ######################################################################## + self.state['diaact'] = "inform" + self.state['inform_slots'][slot] = self.goal['inform_slots'][slot] + if slot in self.state['rest_slots']: self.state['rest_slots'].remove(slot) + else: + if slot in self.state['rest_slots']: + self.state['rest_slots'].remove(slot) + if slot in self.state['request_slots'].keys(): + del self.state['request_slots'][slot] + + if len(self.state['request_slots']) > 0: + request_set = list(self.state['request_slots'].keys()) + if 'ticket' in request_set: + request_set.remove('ticket') + + if len(request_set) > 0: + request_slot = random.choice(request_set) + else: + request_slot = 'ticket' + + self.state['request_slots'][request_slot] = "UNK" + self.state['diaact'] = "request" + elif len(self.state['rest_slots']) > 0: + rest_slot_set = deepcopy(self.state['rest_slots']) + if 'ticket' in rest_slot_set: + rest_slot_set.remove('ticket') + + if len(rest_slot_set) > 0: + inform_slot = random.choice(rest_slot_set) #self.state['rest_slots'] + if inform_slot in self.goal['inform_slots'].keys(): + self.state['inform_slots'][inform_slot] = self.goal['inform_slots'][inform_slot] + self.state['diaact'] = "inform" + self.state['rest_slots'].remove(inform_slot) + + if 'ticket' in self.state['rest_slots']: + self.state['request_slots']['ticket'] = 'UNK' + self.state['diaact'] = "request" + elif inform_slot in self.goal['request_slots'].keys(): + self.state['request_slots'][inform_slot] = self.goal['request_slots'][inform_slot] + self.state['diaact'] = "request" + else: + self.state['request_slots']['ticket'] = 'UNK' + self.state['diaact'] = "request" + else: + self.state['diaact'] = "thanks" # or replies "confirm_answer"
+ + +
[docs]class StateTracker:
    """ The state tracker maintains a record of which request slots are filled and which inform slots are filled """

    def __init__(self, act_set, slot_set, movie_dictionary):
        """ constructor for the StateTracker takes the movie knowledge base and initializes a new episode

        Arguments:
        act_set -- The set of all acts available
        slot_set -- The total set of available slots
        movie_dictionary -- A representation of all the available movies. Generally this object is accessed via the KBHelper class

        Class Variables:
        history_vectors -- A record of the current dialog so far in vector format (act-slot, but no values)
        history_dictionaries -- A record of the current dialog in dictionary format
        current_slots -- A dictionary that keeps a running record of which slots are filled (current_slots['inform_slots']) and which are requested but not yet filled (current_slots['request_slots'])
        action_dimension -- # TODO indicates the dimensionality of the vector representation of the action
        kb_result_dimension -- A single integer denoting the dimension of the kb_results features.
        turn_count -- A running count of which turn we are at in the present dialog
        """
        self.movie_dictionary = movie_dictionary
        self.initialize_episode()
        self.history_vectors = None
        self.history_dictionaries = None
        self.current_slots = None
        self.action_dimension = 10  # TODO REPLACE WITH REAL VALUE
        self.kb_result_dimension = 10  # TODO REPLACE WITH REAL VALUE
        self.turn_count = 0
        self.kb_helper = KBHelper(movie_dictionary)
[docs] def initialize_episode(self): + """ Initialize a new episode (dialog), flush the current state and tracked slots """ + + self.action_dimension = 10 + self.history_vectors = np.zeros((1, self.action_dimension)) + self.history_dictionaries = [] + self.turn_count = 0 + self.current_slots = {} + + self.current_slots['inform_slots'] = {} + self.current_slots['request_slots'] = {} + self.current_slots['proposed_slots'] = {} + self.current_slots['agent_request_slots'] = {}
+ + +
[docs] def dialog_history_vectors(self): + """ Return the dialog history (both user and agent actions) in vector representation """ + return self.history_vectors
+ + +
[docs] def dialog_history_dictionaries(self): + """ Return the dictionary representation of the dialog history (includes values) """ + return self.history_dictionaries
+ + +
[docs] def kb_results_for_state(self): + """ Return the information about the database results based on the currently informed slots """ + ######################################################################## + # TODO Calculate results based on current informed slots + ######################################################################## + kb_results = self.kb_helper.database_results_for_agent(self.current_slots) # replace this with something less ridiculous + # TODO turn results into vector (from dictionary) + results = np.zeros((0, self.kb_result_dimension)) + return results
+ + +
[docs]    def get_state_for_agent(self):
        """ Get the state representations to send to agent """
        # state = {'user_action': self.history_dictionaries[-1], 'current_slots': self.current_slots, 'kb_results': self.kb_results_for_state()}
        state = {'user_action': self.history_dictionaries[-1], 'current_slots': self.current_slots,  # 'kb_results': self.kb_results_for_state(),
                 'kb_results_dict': self.kb_helper.database_results_for_agent(self.current_slots), 'turn': self.turn_count, 'history': self.history_dictionaries,
                 'agent_action': self.history_dictionaries[-2] if len(self.history_dictionaries) > 1 else None}
        return deepcopy(state)
+ +
[docs] def get_suggest_slots_values(self, request_slots): + """ Get the suggested values for request slots """ + + suggest_slot_vals = {} + if len(request_slots) > 0: + suggest_slot_vals = self.kb_helper.suggest_slot_values(request_slots, self.current_slots) + + return suggest_slot_vals
+ +
[docs] def get_current_kb_results(self): + """ get the kb_results for current state """ + kb_results = self.kb_helper.available_results_from_kb(self.current_slots) + return kb_results
+ + +
[docs] def update(self, agent_action=None, user_action=None): + """ Update the state based on the latest action """ + + ######################################################################## + # Make sure that the function was called properly + ######################################################################## + assert(not (user_action and agent_action)) + assert(user_action or agent_action) + + ######################################################################## + # Update state to reflect a new action by the agent. + ######################################################################## + if agent_action: + + #################################################################### + # Handles the act_slot response (with values needing to be filled) + #################################################################### + if agent_action['act_slot_response']: + response = deepcopy(agent_action['act_slot_response']) + + inform_slots = self.kb_helper.fill_inform_slots(response['inform_slots'], self.current_slots) # TODO this doesn't actually work yet, remove this warning when kb_helper is functional + agent_action_values = {'turn': self.turn_count, 'speaker': "agent", 'diaact': response['diaact'], 'inform_slots': inform_slots, 'request_slots':response['request_slots']} + + agent_action['act_slot_response'].update({'diaact': response['diaact'], 'inform_slots': inform_slots, 'request_slots':response['request_slots'], 'turn':self.turn_count}) + + elif agent_action['act_slot_value_response']: + agent_action_values = deepcopy(agent_action['act_slot_value_response']) + # print("Updating state based on act_slot_value action from agent") + agent_action_values['turn'] = self.turn_count + agent_action_values['speaker'] = "agent" + + #################################################################### + # This code should execute regardless of which kind of agent produced action + #################################################################### + for slot in agent_action_values['inform_slots'].keys(): + self.current_slots['proposed_slots'][slot] = agent_action_values['inform_slots'][slot] + self.current_slots['inform_slots'][slot] = agent_action_values['inform_slots'][slot] # add into inform_slots + if slot in self.current_slots['request_slots'].keys(): + del self.current_slots['request_slots'][slot] + + for slot in agent_action_values['request_slots'].keys(): + if slot not in self.current_slots['agent_request_slots']: + self.current_slots['agent_request_slots'][slot] = "UNK" + + self.history_dictionaries.append(agent_action_values) + current_agent_vector = np.ones((1, self.action_dimension)) + self.history_vectors = np.vstack([self.history_vectors, current_agent_vector]) + + ######################################################################## + # Update the state to reflect a new action by the user + ######################################################################## + elif user_action: + + #################################################################### + # Update the current slots + #################################################################### + for slot in user_action['inform_slots'].keys(): + self.current_slots['inform_slots'][slot] = user_action['inform_slots'][slot] + if slot in self.current_slots['request_slots'].keys(): + del self.current_slots['request_slots'][slot] + + for slot in user_action['request_slots'].keys(): + if slot not in self.current_slots['request_slots']: + self.current_slots['request_slots'][slot] = "UNK" + + self.history_vectors = 
np.vstack([self.history_vectors, np.zeros((1,self.action_dimension))]) + new_move = {'turn': self.turn_count, 'speaker': "user", 'request_slots': user_action['request_slots'], 'inform_slots': user_action['inform_slots'], 'diaact': user_action['diaact']} + self.history_dictionaries.append(deepcopy(new_move)) + + ######################################################################## + # This should never happen if the asserts passed + ######################################################################## + else: + pass + + ######################################################################## + # This code should execute after update code regardless of what kind of action (agent/user) + ######################################################################## + self.turn_count += 1
+ + +
[docs]class KBHelper: + """ An assistant to fill in values for the agent (which knows about slots of values) """ + + def __init__(self, movie_dictionary): + """ Constructor for a KBHelper """ + + self.movie_dictionary = movie_dictionary + self.cached_kb = defaultdict(list) + self.cached_kb_slot = defaultdict(list) + + +
[docs] def fill_inform_slots(self, inform_slots_to_be_filled, current_slots): + """ Takes unfilled inform slots and current_slots, returns dictionary of filled informed slots (with values) + + Arguments: + inform_slots_to_be_filled -- Something that looks like {starttime:None, theater:None} where starttime and theater are slots that the agent needs filled + current_slots -- Contains a record of all filled slots in the conversation so far - for now, just use current_slots['inform_slots'] which is a dictionary of the already filled-in slots + + Returns: + filled_in_slots -- A dictionary of form {slot1:value1, slot2:value2} for each sloti in inform_slots_to_be_filled + """ + + kb_results = self.available_results_from_kb(current_slots) + if auto_suggest == 1: + print('Number of movies in KB satisfying current constraints: ', len(kb_results)) + + filled_in_slots = {} + if 'taskcomplete' in inform_slots_to_be_filled.keys(): + filled_in_slots.update(current_slots['inform_slots']) + + for slot in inform_slots_to_be_filled.keys(): + if slot == 'numberofpeople': + if slot in current_slots['inform_slots'].keys(): + filled_in_slots[slot] = current_slots['inform_slots'][slot] + elif slot in inform_slots_to_be_filled.keys(): + filled_in_slots[slot] = inform_slots_to_be_filled[slot] + continue + + if slot == 'ticket' or slot == 'taskcomplete': + filled_in_slots[slot] = TICKET_AVAILABLE if len(kb_results)>0 else NO_VALUE_MATCH + continue + + if slot == 'closing': continue + + #################################################################### + # Grab the value for the slot with the highest count and fill it + #################################################################### + values_dict = self.available_slot_values(slot, kb_results) + + values_counts = [(v, values_dict[v]) for v in values_dict.keys()] + if len(values_counts) > 0: + filled_in_slots[slot] = sorted(values_counts, key = lambda x: -x[1])[0][0] + else: + filled_in_slots[slot] = NO_VALUE_MATCH #"NO VALUE MATCHES SNAFU!!!" + + return filled_in_slots
+ + +
[docs] def available_slot_values(self, slot, kb_results): + """ Return the set of values available for the slot based on the current constraints """ + + slot_values = {} + for movie_id in kb_results.keys(): + if slot in kb_results[movie_id].keys(): + slot_val = kb_results[movie_id][slot] + if slot_val in slot_values.keys(): + slot_values[slot_val] += 1 + else: slot_values[slot_val] = 1 + return slot_values
+ +
[docs] def available_results_from_kb(self, current_slots): + """ Return the available movies in the movie_kb based on the current constraints """ + + ret_result = [] + current_slots = current_slots['inform_slots'] + constrain_keys = current_slots.keys() + + constrain_keys = filter(lambda k : k != 'ticket' and \ + k != 'numberofpeople' and \ + k!= 'taskcomplete' and \ + k != 'closing' , constrain_keys) + constrain_keys = [k for k in constrain_keys if current_slots[k] != I_DO_NOT_CARE] + + query_idx_keys = frozenset(current_slots.items()) + cached_kb_ret = self.cached_kb[query_idx_keys] + + cached_kb_length = len(cached_kb_ret) if cached_kb_ret != None else -1 + if cached_kb_length > 0: + return dict(cached_kb_ret) + elif cached_kb_length == -1: + return dict([]) + + # kb_results = copy.deepcopy(self.movie_dictionary) + for id in self.movie_dictionary.keys(): + kb_keys = self.movie_dictionary[id].keys() + if len(set(constrain_keys).union(set(kb_keys)) ^ (set(constrain_keys) ^ set(kb_keys))) == len( + constrain_keys): + match = True + for idx, k in enumerate(constrain_keys): + if str(current_slots[k]).lower() == str(self.movie_dictionary[id][k]).lower(): + continue + else: + match = False + if match: + self.cached_kb[query_idx_keys].append((id, self.movie_dictionary[id])) + ret_result.append((id, self.movie_dictionary[id])) + + # for slot in current_slots['inform_slots'].keys(): + # if slot == 'ticket' or slot == 'numberofpeople' or slot == 'taskcomplete' or slot == 'closing': continue + # if current_slots['inform_slots'][slot] == dialog_config.I_DO_NOT_CARE: continue + # + # if slot not in self.movie_dictionary[movie_id].keys(): + # if movie_id in kb_results.keys(): + # del kb_results[movie_id] + # else: + # if current_slots['inform_slots'][slot].lower() != self.movie_dictionary[movie_id][slot].lower(): + # if movie_id in kb_results.keys(): + # del kb_results[movie_id] + + if len(ret_result) == 0: + self.cached_kb[query_idx_keys] = None + + ret_result = dict(ret_result) + return ret_result
+ +
[docs] def available_results_from_kb_for_slots(self, inform_slots): + """ Return the count statistics for each constraint in inform_slots """ + + kb_results = {key:0 for key in inform_slots.keys()} + kb_results['matching_all_constraints'] = 0 + + query_idx_keys = frozenset(inform_slots.items()) + cached_kb_slot_ret = self.cached_kb_slot[query_idx_keys] + + if len(cached_kb_slot_ret) > 0: + return cached_kb_slot_ret[0] + + for movie_id in self.movie_dictionary.keys(): + all_slots_match = 1 + for slot in inform_slots.keys(): + if slot == 'ticket' or inform_slots[slot] == I_DO_NOT_CARE: + continue + + if slot in self.movie_dictionary[movie_id]: + # if slot in self.movie_dictionary[movie_id]: + if inform_slots[slot].lower() == self.movie_dictionary[movie_id][slot].lower(): + kb_results[slot] += 1 + else: + all_slots_match = 0 + else: + all_slots_match = 0 + kb_results['matching_all_constraints'] += all_slots_match + + self.cached_kb_slot[query_idx_keys].append(kb_results) + return kb_results
+ + +
[docs] def database_results_for_agent(self, current_slots): + """ A dictionary of the number of results matching each current constraint. The agent needs this to decide what to do next. """ + + database_results ={} # { date:100, distanceconstraints:60, theater:30, matching_all_constraints: 5} + database_results = self.available_results_from_kb_for_slots(current_slots['inform_slots']) + return database_results
+ +
[docs] def suggest_slot_values(self, request_slots, current_slots): + """ Return the suggest slot values """ + + avail_kb_results = self.available_results_from_kb(current_slots) + return_suggest_slot_vals = {} + for slot in request_slots.keys(): + avail_values_dict = self.available_slot_values(slot, avail_kb_results) + values_counts = [(v, avail_values_dict[v]) for v in avail_values_dict.keys()] + + if len(values_counts) > 0: + return_suggest_slot_vals[slot] = [] + sorted_dict = sorted(values_counts, key = lambda x: -x[1]) + for k in sorted_dict: return_suggest_slot_vals[slot].append(k[0]) + else: + return_suggest_slot_vals[slot] = [] + + return return_suggest_slot_vals
+ + + +
[docs]class State(object): + def __init__(self, state=None, reward=None, done=None): + self.states = [state] + self.rewards = [reward] + self.local_done = [done]
+ + +
[docs]class MovieActInActOutEnvironment(object): + def __init__(self, worker_id=None): + self.worker_id = worker_id + self.act_set = act_set + self.slot_set = slot_set + self.movie_dict = movie_dictionary + self.user = RuleSimulator(movie_dictionary, act_set, slot_set, goal_set, usersim_params) + self.state_tracker = StateTracker(act_set, slot_set, movie_dictionary) + self.act_cardinality = len(act_set.keys()) + self.slot_cardinality = len(slot_set.keys()) + self.feasible_actions = feasible_actions + self.num_actions = len(self.feasible_actions) + self.max_turn = agent_params['max_turn'] + 4 + self.state_dimension = 2 * self.act_cardinality + 7 * self.slot_cardinality + 3 + self.max_turn + print(self.num_actions) + print(self.state_dimension) + self.env_info = [State()] + self.stat = {'success':0, 'fail':0} + # self.observation_space = None + # self.action_space = None + +
[docs] def reset(self, train_mode, config): + self.current_slot_id = 0 + self.phase = 0 + self.request_set = ['moviename', 'starttime', 'city', 'date', 'theater', 'numberofpeople'] + self.state_tracker.initialize_episode() + user_action = self.user.initialize_episode() + self.print_function(user_action = user_action) + self.state_tracker.update(user_action = user_action) + state_vector = self.prepare_state_representation(self.state_tracker.get_state_for_agent()) + self.env_info = [State(state_vector, 0, False)] + return self.env_info
+ +
[docs] def step(self, action): + ######################################################################## + # Register AGENT action with the state_tracker + ######################################################################## + agent_action = self.action_decode(action) + self.state_tracker.update(agent_action=agent_action) + self.print_function(agent_action = agent_action['act_slot_response']) + + ######################################################################## + # CALL USER TO TAKE HER TURN + ######################################################################## + sys_action = self.state_tracker.dialog_history_dictionaries()[-1] + user_action, session_over, dialog_status = self.user.next(sys_action) + reward = self.reward_function(dialog_status) + + ######################################################################## + # Update state tracker with latest user action + ######################################################################## + if not session_over: + self.state_tracker.update(user_action = user_action) + self.print_function(user_action = user_action) + else: + if reward > 0: + self.stat['success'] += 1 + else: + self.stat['fail'] += 1 + + state_vector = self.prepare_state_representation(self.state_tracker.get_state_for_agent()) + self.env_info = [State(state_vector, reward, session_over)] + + return self.env_info
+ +
[docs] def reward_function(self, dialog_status): + """ Reward Function 1: a reward function based on the dialog_status """ + if dialog_status == FAILED_DIALOG: + reward = -self.user.max_turn #10 + elif dialog_status == SUCCESS_DIALOG: + reward = 2*self.user.max_turn #20 + else: + reward = -1 + return reward
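To make the trade-off concrete, a small sketch of the total return of an episode under this scheme (max_turn=20 is an assumed value for illustration; the real value comes from the user simulator):

    def episode_return(num_turns, success, max_turn=20):  # max_turn=20 is an assumption
        """Total reward of a dialog under the per-turn/terminal scheme above."""
        step_penalties = -1 * (num_turns - 1)  # -1 for every non-final turn
        terminal = 2 * max_turn if success else -max_turn
        return step_penalties + terminal

    print(episode_return(8, True))   # 2*20 - 7 = 33
    print(episode_return(8, False))  # -20 - 7 = -27

Shorter successful dialogs score strictly higher, which is what pushes the policy toward efficient task completion.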
+ +
[docs] def reward_function_without_penalty(self, dialog_status): + """ Reward Function 2: a reward function without penalty on per turn and failure dialog """ + if dialog_status == FAILED_DIALOG: + reward = 0 + elif dialog_status == SUCCESS_DIALOG: + reward = 2*self.user.max_turn + else: + reward = 0 + return reward
+ +
[docs] def initialize_episode(self): + """ Initialize a new episode. This function is called every time a new episode is run. """ + + self.current_slot_id = 0 + self.phase = 0 + self.request_set = ['moviename', 'starttime', 'city', 'date', 'theater', 'numberofpeople']
+ + +
[docs] def action_decode(self, action): + """ DQN: Input state, output action """ + if isinstance(action, np.ndarray): + action = action[0] + act_slot_response = deepcopy(self.feasible_actions[action]) + return {'act_slot_response': act_slot_response, 'act_slot_value_response': None}
+ + +
[docs] def prepare_state_representation(self, state): + """ Create the representation for each state """ + + user_action = state['user_action'] + current_slots = state['current_slots'] + kb_results_dict = state['kb_results_dict'] + agent_last = state['agent_action'] + + ######################################################################## + # Create one-hot of acts to represent the current user action + ######################################################################## + user_act_rep = np.zeros((1, self.act_cardinality)) + user_act_rep[0, self.act_set[user_action['diaact']]] = 1.0 + + ######################################################################## + # Create bag of inform slots representation to represent the current user action + ######################################################################## + user_inform_slots_rep = np.zeros((1, self.slot_cardinality)) + for slot in user_action['inform_slots'].keys(): + user_inform_slots_rep[0, self.slot_set[slot]] = 1.0 + + ######################################################################## + # Create bag of request slots representation to represent the current user action + ######################################################################## + user_request_slots_rep = np.zeros((1, self.slot_cardinality)) + for slot in user_action['request_slots'].keys(): + user_request_slots_rep[0, self.slot_set[slot]] = 1.0 + + ######################################################################## + # Create bag of filled-in slots based on the current_slots + ######################################################################## + current_slots_rep = np.zeros((1, self.slot_cardinality)) + for slot in current_slots['inform_slots']: + current_slots_rep[0, self.slot_set[slot]] = 1.0 + + ######################################################################## + # Encode last agent act + ######################################################################## + agent_act_rep = np.zeros((1, self.act_cardinality)) + if agent_last: + agent_act_rep[0, self.act_set[agent_last['diaact']]] = 1.0 + + ######################################################################## + # Encode last agent inform slots + ######################################################################## + agent_inform_slots_rep = np.zeros((1, self.slot_cardinality)) + if agent_last: + for slot in agent_last['inform_slots'].keys(): + agent_inform_slots_rep[0, self.slot_set[slot]] = 1.0 + + ######################################################################## + # Encode last agent request slots + ######################################################################## + agent_request_slots_rep = np.zeros((1, self.slot_cardinality)) + if agent_last: + for slot in agent_last['request_slots'].keys(): + agent_request_slots_rep[0, self.slot_set[slot]] = 1.0 + + # scalar turn count, scaled down and broadcast onto a (1, 1) array + turn_rep = np.zeros((1, 1)) + state['turn'] / 10. + + ######################################################################## + # One-hot representation of the turn count + ######################################################################## + turn_onehot_rep = np.zeros((1, self.max_turn)) + turn_onehot_rep[0, state['turn']] = 1.0 + + ######################################################################## + # Representation of KB results (scaled counts) + ######################################################################## + kb_count_rep = np.zeros((1, self.slot_cardinality + 1)) + kb_results_dict['matching_all_constraints'] / 100. 
+ for slot in kb_results_dict: + if slot in self.slot_set: + kb_count_rep[0, self.slot_set[slot]] = kb_results_dict[slot] / 100. + + ######################################################################## + # Representation of KB results (binary) + ######################################################################## + kb_binary_rep = np.zeros((1, self.slot_cardinality + 1)) + np.sum( kb_results_dict['matching_all_constraints'] > 0.) + for slot in kb_results_dict: + if slot in self.slot_set: + kb_binary_rep[0, self.slot_set[slot]] = np.sum( kb_results_dict[slot] > 0.) + + self.final_representation = np.squeeze(np.hstack([user_act_rep, user_inform_slots_rep, user_request_slots_rep, agent_act_rep, agent_inform_slots_rep, agent_request_slots_rep, current_slots_rep, turn_rep, turn_onehot_rep, kb_binary_rep, kb_count_rep])) + return self.final_representation
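The concatenation above is what fixes the state_dimension = 2 * act_cardinality + 7 * slot_cardinality + 3 + max_turn declared in __init__; a quick sanity check with illustrative sizes (the cardinalities below are made-up numbers, not ConvLab's):

    act_c, slot_c, max_turn = 11, 29, 44  # assumed sizes for illustration
    parts = [act_c, slot_c, slot_c,       # user act / inform / request
             act_c, slot_c, slot_c,       # agent act / inform / request
             slot_c,                      # current filled slots
             1, max_turn,                 # scaled turn + one-hot turn
             slot_c + 1, slot_c + 1]      # scaled and binary KB counts
    assert sum(parts) == 2 * act_c + 7 * slot_c + 3 + max_turn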
+ +
[docs] def action_index(self, act_slot_response): + """ Return the index of the given action in feasible_actions """ + + for (i, action) in enumerate(self.feasible_actions): + if act_slot_response == action: + return i + print(act_slot_response) + raise Exception("action index not found")
+ +
[docs] def print_function(self, agent_action=None, user_action=None): + """ Print the agent or user turn according to run_mode """ + + if agent_action: + if run_mode == 0: + print("Turn %d sys: %s" % (agent_action['turn'], agent_action['nl'])) + elif run_mode == 1: + print("Turn %d sys: %s, inform_slots: %s, request slots: %s" % (agent_action['turn'], agent_action['diaact'], agent_action['inform_slots'], agent_action['request_slots'])) + elif run_mode == 2: # debug mode, show both + print("Turn %d sys: %s, inform_slots: %s, request slots: %s" % (agent_action['turn'], agent_action['diaact'], agent_action['inform_slots'], agent_action['request_slots'])) + print("Turn %d sys: %s" % (agent_action['turn'], agent_action['nl'])) + + if auto_suggest == 1: + print('(Suggested Values: %s)' % (self.state_tracker.get_suggest_slots_values(agent_action['request_slots']))) + elif user_action: + if run_mode == 0: + print("Turn %d usr: %s" % (user_action['turn'], user_action['nl'])) + elif run_mode == 1: + print("Turn %d usr: %s, inform_slots: %s, request_slots: %s" % (user_action['turn'], user_action['diaact'], user_action['inform_slots'], user_action['request_slots'])) + elif run_mode == 2: # debug mode, show both + print("Turn %d usr: %s, inform_slots: %s, request_slots: %s" % (user_action['turn'], user_action['diaact'], user_action['inform_slots'], user_action['request_slots'])) + print("Turn %d usr: %s" % (user_action['turn'], user_action['nl']))
+ +
[docs] def rule_policy(self): + """ Rule Policy """ + + if self.current_slot_id < len(self.request_set): + slot = self.request_set[self.current_slot_id] + self.current_slot_id += 1 + + act_slot_response = {} + act_slot_response['diaact'] = "request" + act_slot_response['inform_slots'] = {} + act_slot_response['request_slots'] = {slot: "UNK"} + elif self.phase == 0: + act_slot_response = {'diaact': "inform", 'inform_slots': {'taskcomplete': "PLACEHOLDER"}, 'request_slots': {} } + self.phase += 1 + elif self.phase == 1: + act_slot_response = {'diaact': "thanks", 'inform_slots': {}, 'request_slots': {} } + + return self.action_index(act_slot_response)
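Run on its own, this policy produces a fixed script: request each slot in request_set in order, then inform taskcomplete, then thank the user. A standalone sketch of the dialog-act sequence it emits (not ConvLab code; the dict layout is copied from above):

    request_set = ['moviename', 'starttime', 'city', 'date', 'theater', 'numberofpeople']
    script = [{'diaact': 'request', 'inform_slots': {}, 'request_slots': {s: 'UNK'}} for s in request_set]
    script.append({'diaact': 'inform', 'inform_slots': {'taskcomplete': 'PLACEHOLDER'}, 'request_slots': {}})
    script.append({'diaact': 'thanks', 'inform_slots': {}, 'request_slots': {}})
    for turn, act in enumerate(script):
        print(turn, act['diaact'], act['request_slots'] or act['inform_slots'])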
+ +
[docs] def close(self): + print('\nstatistics: %s' % (self.stat)) + try: + print('\nsuccess rate:', (self.stat['success'] / (self.stat['success'] + self.stat['fail']))) + except ZeroDivisionError: + pass # no completed dialogs to report + print("close")
+ + +
[docs]class MovieEnv(BaseEnv): + ''' + Wrapper around MovieActInActOutEnvironment, exposing a Unity ML-Agents style interface to the Lab. + ''' + + def __init__(self, spec, e=None, env_space=None): + super(MovieEnv, self).__init__(spec, e, env_space) + util.set_attr(self, self.env_spec, [ + 'observation_dim', + 'action_dim', + ]) + worker_id = int(f'{os.getpid()}{self.e+int(ps.unique_id())}'[-4:]) + # TODO dynamically compose components according to env_spec + self.u_env = MovieActInActOutEnvironment(worker_id) + self.patch_gym_spaces(self.u_env) + self._set_attr_from_u_env(self.u_env) + if env_space is None: # singleton mode + pass + else: + self.space_init(env_space) + + logger.info(util.self_desc(self)) +
[docs] def patch_gym_spaces(self, u_env): + ''' + For standardization, represent the observation and action spaces as gym spaces and attach them to the wrapped singleton env. + ''' + observation_shape = (self.env_spec.get('observation_dim'),) + observation_space = spaces.Box(low=0, high=1, shape=observation_shape, dtype=np.int32) + set_gym_space_attr(observation_space) + action_space = spaces.Discrete(self.env_spec.get('action_dim')) + set_gym_space_attr(action_space) + # set for singleton + u_env.observation_space = observation_space + u_env.action_space = action_space
+ + def _get_env_info(self, env_info_dict, a): + '''''' + return self.u_env.env_info[a] + +
[docs] @lab_api + def reset(self): + _reward = np.nan + env_info_dict = self.u_env.reset(train_mode=(util.get_lab_mode() != 'dev'), config=self.env_spec.get('multiwoz')) + a, b = 0, 0 # default singleton aeb + env_info_a = self._get_env_info(env_info_dict, a) + state = env_info_a.states[b] + self.done = done = False + logger.debug(f'Env {self.e} reset reward: {_reward}, state: {state}, done: {done}') + return _reward, state, done
+ +
[docs] @lab_api + def step(self, action): + env_info_dict = self.u_env.step(action) + a, b = 0, 0 # default singleton aeb + env_info_a = self._get_env_info(env_info_dict, a) + reward = env_info_a.rewards[b] * self.reward_scale + state = env_info_a.states[b] + done = env_info_a.local_done[b] + self.done = done = done or self.clock.t > self.max_t + logger.debug(f'Env {self.e} step reward: {reward}, state: {state}, done: {done}') + return reward, state, done
+ +
[docs] @lab_api + def close(self): + self.u_env.close()
+ + # NOTE optional extension for multi-agent-env + +
[docs] @lab_api + def space_init(self, env_space): + '''Post init override for space env. Note that aeb is already correct from __init__''' + self.env_space = env_space + self.aeb_space = env_space.aeb_space + self.observation_spaces = [self.observation_space] + self.action_spaces = [self.action_space]
+ +
[docs] @lab_api + def space_reset(self): + self._check_u_brain_to_agent() + self.done = False + env_info_dict = self.u_env.reset(train_mode=(util.get_lab_mode() != 'dev'), config=self.env_spec.get('multiwoz')) + _reward_e, state_e, done_e = self.env_space.aeb_space.init_data_s(ENV_DATA_NAMES, e=self.e) + for (a, b), body in util.ndenumerate_nonan(self.body_e): + env_info_a = self._get_env_info(env_info_dict, a) + self._check_u_agent_to_body(env_info_a, a) + state = env_info_a.states[b] + state_e[(a, b)] = state + done_e[(a, b)] = self.done + logger.debug(f'Env {self.e} reset reward_e: {_reward_e}, state_e: {state_e}, done_e: {done_e}') + return _reward_e, state_e, done_e
+ +
[docs] @lab_api + def space_step(self, action_e): + # TODO implement clock_speed: step only if self.clock.to_step() + if self.done: + return self.space_reset() + action_e = util.nanflatten(action_e) + env_info_dict = self.u_env.step(action_e) + reward_e, state_e, done_e = self.env_space.aeb_space.init_data_s(ENV_DATA_NAMES, e=self.e) + for (a, b), body in util.ndenumerate_nonan(self.body_e): + env_info_a = self._get_env_info(env_info_dict, a) + reward_e[(a, b)] = env_info_a.rewards[b] * self.reward_scale + state_e[(a, b)] = env_info_a.states[b] + done_e[(a, b)] = env_info_a.local_done[b] + self.done = (util.nonan_all(done_e) or self.clock.t > self.max_t) + logger.debug(f'Env {self.e} step reward_e: {reward_e}, state_e: {state_e}, done_e: {done_e}') + return reward_e, state_e, done_e
diff --git a/docs/build/html/_modules/convlab/env/multiwoz.html b/docs/build/html/_modules/convlab/env/multiwoz.html new file mode 100644 index 0000000..50870a2 --- /dev/null +++ b/docs/build/html/_modules/convlab/env/multiwoz.html @@ -0,0 +1,451 @@ convlab.env.multiwoz — ConvLab 0.1 documentation

Source code for convlab.env.multiwoz

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import math
+import os
+from copy import deepcopy
+
+import numpy as np
+import pydash as ps
+from gym import spaces
+
+from convlab import evaluator
+import convlab.modules.nlg.multiwoz as nlg
+import convlab.modules.nlu.multiwoz as nlu
+import convlab.modules.policy.system.multiwoz as sys_policy
+import convlab.modules.policy.user.multiwoz as user_policy
+from convlab.modules.policy.user.multiwoz import UserPolicyAgendaMultiWoz
+from convlab.modules.usr import UserSimulator
+from convlab.env.base import BaseEnv, set_gym_space_attr
+from convlab.lib import logger, util
+from convlab.lib.decorator import lab_api
+from convlab.modules.action_decoder.multiwoz.multiwoz_vocab_action_decoder import ActionVocab
+from convlab.modules.policy.system.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
+
+logger = logger.get_logger(__name__)
+
+
[docs]class State(object): + def __init__(self, state=None, reward=None, done=None): + self.states = [state] + self.rewards = [reward] + self.local_done = [done]
+ + +
[docs]class MultiWozEnvironment(object): + def __init__(self, env_spec, worker_id=None, action_dim=300): + self.env_spec = env_spec + self.worker_id = worker_id + self.observation_space = None + self.action_space = None + + self.agenda = UserPolicyAgendaMultiWoz() # Agenda-based Simulator (act-in act-out) + if 'user_policy' in self.env_spec: + params = deepcopy(ps.get(self.env_spec, 'user_policy')) + AgendaClass = getattr(user_policy, params.pop('name')) + self.agenda = AgendaClass() + + self.nlu = None + if 'nlu' in self.env_spec: + params = deepcopy(ps.get(self.env_spec, 'nlu')) + if not params['name']: + self.nlu = None + else: + NluClass = getattr(nlu, params.pop('name')) + self.nlu = NluClass(**params) + + self.nlg = None + if 'nlg' in self.env_spec: + params = deepcopy(ps.get(self.env_spec, 'nlg')) + if not params['name']: + self.nlg = None + else: + NlgClass = getattr(nlg, params.pop('name')) + self.nlg = NlgClass(**params) + + self.sys_policy = RuleBasedMultiwozBot() + if 'sys_policy' in self.env_spec: + params = deepcopy(ps.get(self.env_spec, 'sys_policy')) + SysPolicy = getattr(sys_policy, params.pop('name')) + self.sys_policy = SysPolicy() + + self.evaluator = None + if 'evaluator' in self.env_spec: + params = deepcopy(ps.get(self.env_spec, 'evaluator')) + EvaluatorClass = getattr(evaluator, params.pop('name')) + self.evaluator = EvaluatorClass(**params) + + self.simulator = UserSimulator(self.nlu, self.agenda, self.nlg) + self.simulator.init_session() + self.action_vocab = ActionVocab(num_actions=action_dim) + self.history = [] + self.last_act = None + + self.stat = {'success':0, 'fail':0} + +
[docs] def reset(self, train_mode, config): + self.simulator.init_session() + self.history = [] + user_response, user_act, session_over, reward = self.simulator.response("null", self.history) + self.last_act = user_act + logger.act(f'User action: {user_act}') + self.history.extend(["null", f'{user_response}']) + self.env_info = [State(user_response, 0., session_over)] + # update evaluator + if self.evaluator: + self.evaluator.add_goal(self.get_goal()) + logger.act(f'Goal: {self.get_goal()}') + return self.env_info
+ +
[docs] def get_goal(self): + return deepcopy(self.simulator.policy.domain_goals)
+ +
[docs] def get_last_act(self): + return deepcopy(self.last_act)
+ +
[docs] def get_sys_act(self): + return deepcopy(self.simulator.sys_act)
+ +
[docs] def step(self, action): + user_response, user_act, session_over, reward = self.simulator.response(action, self.history) + self.last_act = user_act + # record the actual system action and user response in the dialog history + self.history.extend([f'{action}', f'{user_response}']) + logger.act(f'Inferred system action: {self.get_sys_act()}') + # update evaluator + if self.evaluator: + self.evaluator.add_sys_da(self.get_sys_act()) + self.evaluator.add_usr_da(self.get_last_act()) + if session_over: + # overwrite the simulator reward with the evaluator-based terminal reward + reward = 2.0 * self.simulator.policy.max_turn if self.evaluator.task_success() else -1.0 * self.simulator.policy.max_turn + else: + reward = -1.0 + self.env_info = [State(user_response, reward, session_over)] + return self.env_info
+ +
[docs] def rule_policy(self, state, algorithm, body): + def find_best_delex_act(action): + def _score(a1, a2): + score = 0 + for domain_act in a1: + if domain_act not in a2: + score += len(a1[domain_act]) + else: + score += len(set(a1[domain_act]) - set(a2[domain_act])) + return score + + best_p_action_index = -1 + best_p_score = math.inf + best_pn_action_index = -1 + best_pn_score = math.inf + for i, v_action in enumerate(self.action_vocab.vocab): + if v_action == action: + return i + else: + p_score = _score(action, v_action) + n_score = _score(v_action, action) + if p_score > 0 and n_score == 0 and p_score < best_p_score: + best_p_action_index = i + best_p_score = p_score + else: + if p_score + n_score < best_pn_score: + best_pn_action_index = i + best_pn_score = p_score + n_score + if best_p_action_index >= 0: + return best_p_action_index + return best_pn_action_index + + rule_act = self.sys_policy.predict(state) + delex_act = {} + for domain_act in rule_act: + domain, act_type = domain_act.split('-', 1) + if act_type in ['NoOffer', 'OfferBook']: + delex_act[domain_act] = ['none'] + elif act_type in ['Select']: + for sv in rule_act[domain_act]: + if sv[0] != "none": + delex_act[domain_act] = [sv[0]] + break + else: + delex_act[domain_act] = [sv[0] for sv in rule_act[domain_act]] + action = find_best_delex_act(delex_act) + + return action
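The nested _score above is asymmetric: _score(a1, a2) counts how many slot values of a1 are not covered by a2, so a zero score in one direction and a positive score in the other means "a2 is a strict subset of a1". A standalone worked example (the domain-act names and slots are made up):

    def _score(a1, a2):
        score = 0
        for domain_act in a1:
            if domain_act not in a2:
                score += len(a1[domain_act])
            else:
                score += len(set(a1[domain_act]) - set(a2[domain_act]))
        return score

    rule = {'Hotel-Inform': ['Area', 'Price'], 'Hotel-Request': ['Stars']}
    cand = {'Hotel-Inform': ['Area']}
    print(_score(rule, cand))  # 2: 'Price' unmatched + the whole 'Hotel-Request' act missing
    print(_score(cand, rule))  # 0: everything in cand is covered by rule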
+ +
[docs] def close(self): + pass
+ + +
[docs]class MultiWozEnv(BaseEnv): + ''' + Wrapper around MultiWozEnvironment, exposing a Unity ML-Agents style interface to the Lab. + ''' + + def __init__(self, spec, e=None): + super(MultiWozEnv, self).__init__(spec, e) + self.action_dim = self.observation_dim = 0 + util.set_attr(self, self.env_spec, [ + 'observation_dim', + 'action_dim', + ]) + worker_id = int(f'{os.getpid()}{self.e+int(ps.unique_id())}'[-4:]) + self.u_env = MultiWozEnvironment(self.env_spec, worker_id, self.action_dim) + self.evaluator = self.u_env.evaluator + self.patch_gym_spaces(self.u_env) + self._set_attr_from_u_env(self.u_env) + + logger.info(util.self_desc(self)) +
[docs] def patch_gym_spaces(self, u_env): + ''' + For standardization, represent the observation and action spaces as gym spaces and attach them to the wrapped singleton env. + ''' + observation_shape = (self.observation_dim,) + observation_space = spaces.Box(low=0, high=1, shape=observation_shape, dtype=np.int32) + set_gym_space_attr(observation_space) + action_space = spaces.Discrete(self.action_dim) + set_gym_space_attr(action_space) + # set for singleton + u_env.observation_space = observation_space + u_env.action_space = action_space
+ + def _get_env_info(self, env_info_dict, a): + '''''' + return self.u_env.env_info[a] + +
[docs] @lab_api + def reset(self): + # _reward = np.nan + env_info_dict = self.u_env.reset(train_mode=(util.get_lab_mode() != 'dev'), config=self.env_spec.get('multiwoz')) + a, b = 0, 0 # default singleton aeb + env_info_a = self._get_env_info(env_info_dict, a) + state = env_info_a.states[b] + self.done = False + logger.debug(f'Env {self.e} reset state: {state}') + return state
+ +
[docs] @lab_api + def step(self, action): + env_info_dict = self.u_env.step(action) + a, b = 0, 0 # default singleton aeb + env_info_a = self._get_env_info(env_info_dict, a) + reward = env_info_a.rewards[b] # * self.reward_scale + state = env_info_a.states[b] + done = env_info_a.local_done[b] + self.done = done = done or self.clock.t > self.max_t + logger.debug(f'Env {self.e} step reward: {reward}, state: {state}, done: {done}') + return state, reward, done, env_info_a
+ +
[docs] @lab_api + def close(self): + self.u_env.close()
+ +
[docs] def get_goal(self): + return self.u_env.get_goal()
+ +
[docs] def get_last_act(self): + return self.u_env.get_last_act()
+ +
[docs] def get_sys_act(self): + return self.u_env.get_sys_act()
+ +
[docs] def get_task_success(self): + return self.u_env.simulator.policy.goal.task_complete()
diff --git a/docs/build/html/_modules/convlab/evaluator/evaluator.html b/docs/build/html/_modules/convlab/evaluator/evaluator.html new file mode 100644 index 0000000..c2384ad --- /dev/null +++ b/docs/build/html/_modules/convlab/evaluator/evaluator.html @@ -0,0 +1,238 @@ convlab.evaluator.evaluator — ConvLab 0.1 documentation

Source code for convlab.evaluator.evaluator

+# -*- coding: utf-8 -*-
+
+
[docs]class Evaluator(object): + def __init__(self): + raise NotImplementedError + +
[docs] def add_goal(self, goal): + """ + init goal and array + args: + goal: dict[domain] dict['info'/'book'/'reqt'] dict/dict/list[slot] + """ + raise NotImplementedError
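For reference, a goal in the format described above might look like this (the domain, slots, and values are made-up MultiWOZ-style examples):

    goal = {
        'restaurant': {
            'info': {'food': 'italian', 'area': 'centre'},  # constraints the user informs
            'book': {'people': '4', 'day': 'friday'},       # booking constraints
            'reqt': ['phone', 'postcode'],                  # slots the user will request
        },
    }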
+ +
[docs] def add_sys_da(self, da_turn): + """ + add sys_da into array + args: + da_turn: dict[domain-intent] list[slot, value] + """ + raise NotImplementedError
+ +
[docs] def add_usr_da(self, da_turn): + """ + add usr_da into array + args: + da_turn: dict[domain-intent] list[slot, value] + """ + raise NotImplementedError
+ +
[docs] def book_rate(self, ref2goal=True, aggregate=True): + """ + judge if the selected entity meets the constraint + """ + raise NotImplementedError
+ +
[docs] def inform_F1(self, ref2goal=True, aggregate=True): + """ + judge if all the requested information is answered + """ + raise NotImplementedError
+ +
[docs] def task_success(self, ref2goal=True): + """ + judge if all the domains are successfully completed + """ + raise NotImplementedError
+ +
[docs] def domain_success(self, domain, ref2goal=True): + """ + judge if the domain (subtask) is successfully completed + """ + raise NotImplementedError
diff --git a/docs/build/html/_modules/convlab/evaluator/multiwoz.html b/docs/build/html/_modules/convlab/evaluator/multiwoz.html new file mode 100644 index 0000000..bdb68a6 --- /dev/null +++ b/docs/build/html/_modules/convlab/evaluator/multiwoz.html @@ -0,0 +1,472 @@ convlab.evaluator.multiwoz — ConvLab 0.1 documentation

Source code for convlab.evaluator.multiwoz

+# -*- coding: utf-8 -*-
+
+import re
+import numpy as np
+from copy import deepcopy
+
+from convlab.evaluator.evaluator import Evaluator
+from convlab.modules.util.multiwoz.dbquery import dbs
+
+requestable = \
+{'attraction': ['post', 'phone', 'addr', 'fee', 'area', 'type'],
+ 'restaurant': ['addr', 'phone', 'post', 'ref', 'price', 'area', 'food'],
+ 'train': ['ticket', 'time', 'ref', 'id', 'arrive', 'leave'],
+ 'hotel': ['addr', 'post', 'phone', 'ref', 'price', 'internet', 'parking', 'area', 'type', 'stars'],
+ 'taxi': ['car', 'phone'],
+ 'hospital': ['post', 'phone', 'addr'],
+ 'police': ['addr', 'post', 'phone']}
+
+belief_domains = requestable.keys()
+
+mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'name': 'name', 'phone': 'phone', 'post': 'postcode', 'price': 'pricerange'},
+        'hotel': {'addr': 'address', 'area': 'area', 'internet': 'internet', 'parking': 'parking', 'name': 'name', 'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 'type': 'type'},
+        'attraction': {'addr': 'address', 'area': 'area', 'fee': 'entrance fee', 'name': 'name', 'phone': 'phone', 'post': 'postcode', 'type': 'type'},
+        'train': {'id': 'trainID', 'arrive': 'arriveBy', 'day': 'day', 'depart': 'departure', 'dest': 'destination', 'time': 'duration', 'leave': 'leaveAt', 'ticket': 'price'},
+        'taxi': {'car': 'car type', 'phone': 'phone'},
+        'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department'},
+        'police': {'post': 'postcode', 'phone': 'phone', 'addr': 'address'}}
+
+
[docs]class MultiWozEvaluator(Evaluator): + def __init__(self): + self.sys_da_array = [] + self.usr_da_array = [] + self.goal = {} + self.cur_domain = '' + self.booked = {} + + def _init_dict(self): + dic = {} + for domain in belief_domains: + dic[domain] = {'info':{}, 'book':{}, 'reqt':[]} + return dic + + def _init_dict_booked(self): + dic = {} + for domain in belief_domains: + dic[domain] = None + return dic + + def _expand(self, _goal): + goal = deepcopy(_goal) + for domain in belief_domains: + if domain not in goal: + goal[domain] = {'info':{}, 'book':{}, 'reqt':[]} + continue + if 'info' not in goal[domain]: + goal[domain]['info'] = {} + if 'book' not in goal[domain]: + goal[domain]['book'] = {} + if 'reqt' not in goal[domain]: + goal[domain]['reqt'] = [] + return goal + +
[docs] def add_goal(self, goal): + """ + init goal and array + args: + goal: dict[domain] dict['info'/'book'/'reqt'] dict/dict/list[slot] + """ + self.sys_da_array = [] + self.usr_da_array = [] + self.goal = goal + self.cur_domain = '' + self.booked = self._init_dict_booked()
+ +
[docs] def add_sys_da(self, da_turn): + """ + add sys_da into array + args: + da_turn: dict[domain-intent] list[slot, value] + """ + for dom_int in da_turn: + domain = dom_int.split('-')[0].lower() + if domain in belief_domains and domain != self.cur_domain: + self.cur_domain = domain + slot_pair = da_turn[dom_int] + for slot, value in slot_pair: + da = (dom_int +'-'+slot).lower() + value = str(value) + self.sys_da_array.append(da+'-'+value) + + if da == 'booking-book-ref' and self.cur_domain in ['hotel', 'restaurant', 'train']: + if not self.booked[self.cur_domain] and re.match(r'^\d{8}$', value): + self.booked[self.cur_domain] = dbs[self.cur_domain][int(value)] + elif da == 'train-offerbook-ref' or da == 'train-inform-ref': + if not self.booked['train'] and re.match(r'^\d{8}$', value): + self.booked['train'] = dbs['train'][int(value)] + elif da == 'taxi-inform-car': + if not self.booked['taxi']: + self.booked['taxi'] = 'booked'
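Note the trick above: a well-formed booking reference is an 8-digit, zero-padded row index into the domain database, which is how the booked entity is recovered later for book_rate. A minimal sketch (the value is made up):

    import re

    value = '00000042'
    if re.match(r'^\d{8}$', value):
        row = int(value)  # would index dbs['hotel'][42]
        print(row)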
+ +
[docs] def add_usr_da(self, da_turn): + """ + add usr_da into array + args: + da_turn: dict[domain-intent] list[slot, value] + """ + for dom_int in da_turn: + domain = dom_int.split('-')[0].lower() + if domain in belief_domains and domain != self.cur_domain: + self.cur_domain = domain + slot_pair = da_turn[dom_int] + for slot, value in slot_pair: + da = (dom_int +'-'+slot).lower() + value = str(value) + self.usr_da_array.append(da+'-'+value)
+ + def _book_rate_goal(self, goal, booked_entity, domains=None): + """ + judge if the selected entity meets the constraint + """ + if domains is None: + domains = belief_domains + score = [] + for domain in domains: + if goal[domain]['book']: + tot = len(goal[domain]['info'].keys()) + if tot == 0: + continue + entity = booked_entity[domain] + if entity is None: + score.append(0) + continue + if domain == 'taxi': + score.append(1) + continue + match = 0 + for k, v in goal[domain]['info'].items(): + if k in ['destination', 'departure', 'name']: + tot -= 1 + elif k == 'leaveAt': + try: + v_constraint = int(v.split(':')[0]) * 100 + int(v.split(':')[1]) + v_select = int(entity['leaveAt'].split(':')[0]) * 100 + int(entity['leaveAt'].split(':')[1]) + if v_constraint <= v_select: + match += 1 + except (ValueError, IndexError): + match += 1 + elif k == 'arriveBy': + try: + v_constraint = int(v.split(':')[0]) * 100 + int(v.split(':')[1]) + v_select = int(entity['arriveBy'].split(':')[0]) * 100 + int(entity['arriveBy'].split(':')[1]) + if v_constraint >= v_select: + match += 1 + except (ValueError, IndexError): + match += 1 + else: + if v.strip() == entity[k].strip(): + match += 1 + if tot != 0: + score.append(match / tot) + return score + + def _inform_F1_goal(self, goal, sys_history, domains=None): + """ + judge if all the requested information is answered + """ + if domains is None: + domains = belief_domains + inform_slot = {} + for domain in domains: + inform_slot[domain] = set() + for da in sys_history: + domain, intent, slot, value = da.split('-', 3) + if intent in ['inform', 'recommend', 'offerbook', 'offerbooked'] and domain in domains and slot in mapping[domain]: + inform_slot[domain].add(mapping[domain][slot]) + TP, FP, FN = 0, 0, 0 + for domain in domains: + for k in goal[domain]['reqt']: + if k in inform_slot[domain]: + TP += 1 + else: + FN += 1 + for k in inform_slot[domain]: + # exclude slots that are informed by users + if k not in goal[domain]['reqt'] \ + and k not in goal[domain]['info'] \ + and k in requestable[domain]: + FP += 1 + return TP, FP, FN + +
[docs] def book_rate(self, ref2goal=True, aggregate=True): + if ref2goal: + goal = self._expand(self.goal) + else: + goal = self._init_dict() + for domain in belief_domains: + if domain in self.goal and 'book' in self.goal[domain]: + goal[domain]['book'] = self.goal[domain]['book'] + for da in self.usr_da_array: + d, i, s, v = da.split('-', 3) + if i == 'inform' and s in mapping[d]: + goal[d]['info'][mapping[d][s]] = v + score = self._book_rate_goal(goal, self.booked) + if aggregate: + return np.mean(score) if score else None + else: + return score
+ +
[docs] def inform_F1(self, ref2goal=True, aggregate=True): + if ref2goal: + goal = self._expand(self.goal) + else: + goal = self._init_dict() + for da in self.usr_da_array: + d, i, s, v = da.split('-', 3) + if i == 'inform' and s in mapping[d]: + goal[d]['info'][mapping[d][s]] = v + elif i == 'request': + goal[d]['reqt'].append(s) + TP, FP, FN = self._inform_F1_goal(goal, self.sys_da_array) + if aggregate: + try: + rec = TP / (TP + FN) + except ZeroDivisionError: + return None, None, None + try: + prec = TP / (TP + FP) + F1 = 2 * prec * rec / (prec + rec) + except ZeroDivisionError: + return 0, rec, 0 + return prec, rec, F1 + else: + return [TP, FP, FN]
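From the TP/FP/FN counts the standard precision/recall/F1 formulas follow; with made-up counts TP=3, FP=1, FN=1:

    TP, FP, FN = 3, 1, 1
    prec = TP / (TP + FP)                # 0.75
    rec = TP / (TP + FN)                 # 0.75
    F1 = 2 * prec * rec / (prec + rec)   # 0.75
    print(prec, rec, F1)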
+ +
[docs] def task_success(self, ref2goal=True): + """ + judge if all the domains are successfully completed + """ + book_sess = self.book_rate(ref2goal) + inform_sess = self.inform_F1(ref2goal) + # book rate == 1 & inform recall == 1 + if (book_sess == 1 and inform_sess[1] == 1) \ + or (book_sess == 1 and inform_sess[1] is None) \ + or (book_sess is None and inform_sess[1] == 1): + return 1 + else: + return 0
+ +
[docs] def domain_success(self, domain, ref2goal=True): + """ + judge if the domain (subtask) is successfully completed + """ + if domain not in self.goal: + return None + + if ref2goal: + goal = {} + goal[domain] = deepcopy(self.goal[domain]) + else: + goal = {} + goal[domain] = {'info':{}, 'book':{}, 'reqt':[]} + if 'book' in self.goal[domain]: + goal[domain]['book'] = self.goal[domain]['book'] + for da in self.usr_da_array: + d, i, s, v = da.split('-', 3) + if d != domain: + continue + if i == 'inform' and s in mapping[d]: + goal[d]['info'][mapping[d][s]] = v + elif i == 'request': + goal[d]['reqt'].append(s) + + book_rate = self._book_rate_goal(goal, self.booked, [domain]) + book_rate = np.mean(book_rate) if book_rate else None + + inform = self._inform_F1_goal(goal, self.sys_da_array, [domain]) + try: + inform_rec = inform[0] / (inform[0] + inform[2]) + except ZeroDivisionError: + inform_rec = None + + if (book_rate == 1 and inform_rec == 1) \ + or (book_rate == 1 and inform_rec is None) \ + or (book_rate is None and inform_rec == 1): + return 1 + else: + return 0
diff --git a/docs/build/html/_modules/convlab/experiment/analysis.html b/docs/build/html/_modules/convlab/experiment/analysis.html new file mode 100644 index 0000000..fd6d485 --- /dev/null +++ b/docs/build/html/_modules/convlab/experiment/analysis.html @@ -0,0 +1,394 @@ convlab.experiment.analysis — ConvLab 0.1 documentation

Source code for convlab.experiment.analysis

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import shutil
+
+import numpy as np
+import pandas as pd
+import pydash as ps
+import torch
+
+from convlab.lib import logger, util, viz
+
+NUM_EVAL = 4
+METRICS_COLS = [
+    'strength', 'max_strength', 'final_strength',
+    'sample_efficiency', 'training_efficiency',
+    'stability', 'consistency',
+]
+
+logger = logger.get_logger(__name__)
+
+
+# methods to generate returns (total rewards)
+
+
[docs]def gen_return(agent, env): + '''Generate return for an agent and an env in eval mode''' + obs = env.reset() + agent.reset(obs) + done = False + total_reward = 0 + env.clock.tick('epi') + env.clock.tick('t') + while not done: + action = agent.act(obs) + next_obs, reward, done, info = env.step(action) + agent.update(obs, action, reward, next_obs, done) + obs = next_obs + total_reward += reward + env.clock.tick('t') + return total_reward
+ + +
[docs]def gen_avg_return(agent, env, num_eval=NUM_EVAL): + '''Generate average return for agent and an env''' + with util.ctx_lab_mode('eval'): # enter eval context + agent.algorithm.update() # set explore_var etc. to end_val under ctx + with torch.no_grad(): + returns = [gen_return(agent, env) for i in range(num_eval)] + # exit eval context, restore variables simply by updating + agent.algorithm.update() + return np.mean(returns)
+ + +
[docs]def gen_result(agent, env): + '''Generate a single return for an agent and an env in eval mode''' + with util.ctx_lab_mode('eval'): # enter eval context + agent.algorithm.update() # set explore_var etc. to end_val under ctx + with torch.no_grad(): + _return = gen_return(agent, env) + # exit eval context, restore variables simply by updating + agent.algorithm.update() + return _return
+ + +
[docs]def gen_avg_result(agent, env, num_eval=NUM_EVAL): + returns, lens, successes, precs, recs, f1s, book_rates = [], [], [], [], [], [], [] + for _ in range(num_eval): + returns.append(gen_result(agent, env)) + lens.append(env.clock.t) + if env.evaluator: + successes.append(env.evaluator.task_success()) + _p, _r, _f1 = env.evaluator.inform_F1() + if _f1 is not None: + precs.append(_p) + recs.append(_r) + f1s.append(_f1) + _book = env.evaluator.book_rate() + if _book is not None: + book_rates.append(_book) + elif hasattr(env, 'get_task_success'): + successes.append(env.get_task_success()) + logger.nl(f'---A dialog session is done---') + mean_success = None if len(successes) == 0 else np.mean(successes) + mean_p = None if len(precs) == 0 else np.mean(precs) + mean_r = None if len(recs) == 0 else np.mean(recs) + mean_f1 = None if len(f1s) == 0 else np.mean(f1s) + mean_book_rate = None if len(book_rates) == 0 else np.mean(book_rates) + return np.mean(returns), np.mean(lens), mean_success, mean_p, mean_r, mean_f1, mean_book_rate
+ + +
[docs]def calc_session_metrics(session_df, env_name, info_prepath=None, df_mode=None): + ''' + Calculate the session metrics: strength, efficiency, stability + @param DataFrame:session_df Dataframe containing reward, frame, opt_step + @param str:env_name Name of the environment to get its random baseline + @param str:info_prepath Optional info_prepath to auto-save the output to + @param str:df_mode Optional df_mode to save with info_prepath + @returns dict:metrics Consists of scalar metrics and series local metrics + ''' + mean_return = session_df['avg_return'] + # length and success statistics are only tracked in eval mode + mean_length = session_df['avg_len'] if df_mode == 'eval' else None + mean_success = session_df['avg_success'] if df_mode == 'eval' else None + frames = session_df['frame'] + opt_steps = session_df['opt_step'] + + # all the session local metrics + local = { + 'mean_return': mean_return, + 'mean_length': mean_length, + 'mean_success': mean_success, + 'frames': frames, + 'opt_steps': opt_steps, + } + metrics = { + 'local': local, + } + if info_prepath is not None: # auto-save if info_prepath is given + util.write(metrics, f'{info_prepath}_session_metrics_{df_mode}.pkl') + return metrics
+ + +
[docs]def calc_trial_metrics(session_metrics_list, info_prepath=None): + ''' + Calculate the trial metrics: mean(strength), mean(efficiency), mean(stability), consistency + @param list:session_metrics_list The metrics collected from each session; format: {session_index: {'scalar': {...}, 'local': {...}}} + @param str:info_prepath Optional info_prepath to auto-save the output to + @returns dict:metrics Consists of scalar metrics and series local metrics + ''' + # calculate mean of session metrics + mean_return_list = [sm['local']['mean_return'] for sm in session_metrics_list] + mean_length_list = [sm['local']['mean_length'] for sm in session_metrics_list] + mean_success_list = [sm['local']['mean_success'] for sm in session_metrics_list] + frames = session_metrics_list[0]['local']['frames'] + opt_steps = session_metrics_list[0]['local']['opt_steps'] + + # for plotting: gather all local series of sessions + local = { + 'mean_return': mean_return_list, + 'mean_length': mean_length_list, + 'mean_success': mean_success_list, + 'frames': frames, + 'opt_steps': opt_steps, + } + metrics = { + 'local': local, + } + if info_prepath is not None: # auto-save if info_prepath is given + util.write(metrics, f'{info_prepath}_trial_metrics.pkl') + return metrics
+ + +
[docs]def calc_experiment_df(trial_data_dict, info_prepath=None): + '''Collect all trial data (metrics and config) from trials into a dataframe''' + experiment_df = pd.DataFrame(trial_data_dict).transpose() + cols = METRICS_COLS + config_cols = sorted(ps.difference(experiment_df.columns.tolist(), cols)) + sorted_cols = config_cols + cols + experiment_df = experiment_df.reindex(sorted_cols, axis=1) + experiment_df.sort_values(by=['strength'], ascending=False, inplace=True) + if info_prepath is not None: + util.write(experiment_df, f'{info_prepath}_experiment_df.csv') + # save important metrics in info_prepath directly + util.write(experiment_df, f'{info_prepath.replace("info/", "")}_experiment_df.csv') + return experiment_df
+ + +# interface analyze methods + +
[docs]def analyze_session(session_spec, session_df, df_mode): + '''Analyze session and save data, then return metrics. Note there are 2 types of session_df: body.eval_df and body.train_df''' + info_prepath = session_spec['meta']['info_prepath'] + session_df = session_df.copy() + assert len(session_df) > 1, f'Need more than 1 datapoint to calculate metrics' + util.write(session_df, f'{info_prepath}_session_df_{df_mode}.csv') + # calculate metrics + session_metrics = calc_session_metrics(session_df, ps.get(session_spec, 'env.0.name'), info_prepath, df_mode) + # plot graph + viz.plot_session(session_spec, session_metrics, session_df, df_mode) + return session_metrics
+ + +
[docs]def analyze_trial(trial_spec, session_metrics_list): + '''Analyze trial and save data, then return metrics''' + info_prepath = trial_spec['meta']['info_prepath'] + # calculate metrics + trial_metrics = calc_trial_metrics(session_metrics_list, info_prepath) + # plot graphs + viz.plot_trial(trial_spec, trial_metrics) + # zip files + if util.get_lab_mode() == 'train': + predir, _, _, _, _, _ = util.prepath_split(info_prepath) + shutil.make_archive(predir, 'zip', predir) + logger.info(f'All trial data zipped to {predir}.zip') + return trial_metrics
+ + +
[docs]def analyze_experiment(spec, trial_data_dict): + '''Analyze experiment and save data''' + info_prepath = spec['meta']['info_prepath'] + util.write(trial_data_dict, f'{info_prepath}_trial_data_dict.json') + # calculate experiment df + experiment_df = calc_experiment_df(trial_data_dict, info_prepath) + # plot graph + viz.plot_experiment(spec, experiment_df, METRICS_COLS) + # zip files + predir, _, _, _, _, _ = util.prepath_split(info_prepath) + shutil.make_archive(predir, 'zip', predir) + logger.info(f'All experiment data zipped to {predir}.zip') + return experiment_df
+
diff --git a/docs/build/html/_modules/convlab/experiment/retro_analysis.html b/docs/build/html/_modules/convlab/experiment/retro_analysis.html new file mode 100644 index 0000000..eccb890 --- /dev/null +++ b/docs/build/html/_modules/convlab/experiment/retro_analysis.html @@ -0,0 +1,265 @@ convlab.experiment.retro_analysis — ConvLab 0.1 documentation

Source code for convlab.experiment.retro_analysis

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+# The retro analysis module
+# Runs analysis post-hoc using existing data files
+# example: yarn retro_analyze data/reinforce_cartpole_2018_01_22_211751/
+from glob import glob
+
+import pydash as ps
+
+from convlab.experiment import analysis
+from convlab.lib import logger, util
+
+logger = logger.get_logger(__name__)
+
+
+
[docs]def retro_analyze_sessions(predir): + '''Retro analyze all sessions''' + logger.info('Running retro_analyze_sessions') + session_spec_paths = glob(f'{predir}/*_s*_spec.json') + util.parallelize(_retro_analyze_session, [(p,) for p in session_spec_paths], num_cpus=util.NUM_CPUS)
+ + +def _retro_analyze_session(session_spec_path): + '''Method to retro analyze a single session given only a path to its spec''' + session_spec = util.read(session_spec_path) + info_prepath = session_spec['meta']['info_prepath'] + for df_mode in ('eval', 'train'): + session_df = util.read(f'{info_prepath}_session_df_{df_mode}.csv') + analysis.analyze_session(session_spec, session_df, df_mode) + + +
[docs]def retro_analyze_trials(predir): + '''Retro analyze all trials''' + logger.info('Running retro_analyze_trials') + session_spec_paths = glob(f'{predir}/*_s*_spec.json') + # remove session spec paths + trial_spec_paths = ps.difference(glob(f'{predir}/*_t*_spec.json'), session_spec_paths) + util.parallelize(_retro_analyze_trial, [(p,) for p in trial_spec_paths], num_cpus=util.NUM_CPUS)
+ + +def _retro_analyze_trial(trial_spec_path): + '''Method to retro analyze a single trial given only a path to its spec''' + trial_spec = util.read(trial_spec_path) + meta_spec = trial_spec['meta'] + info_prepath = meta_spec['info_prepath'] + session_metrics_list = [util.read(f'{info_prepath}_s{s}_session_metrics_eval.pkl') for s in range(meta_spec['max_session'])] + analysis.analyze_trial(trial_spec, session_metrics_list) + + +
[docs]def retro_analyze_experiment(predir): + '''Retro analyze an experiment''' + logger.info('Running retro_analyze_experiment') + trial_spec_paths = glob(f'{predir}/*_t*_spec.json') + # remove trial and session spec paths + experiment_spec_paths = ps.difference(glob(f'{predir}/*_spec.json'), trial_spec_paths) + experiment_spec_path = experiment_spec_paths[0] + spec = util.read(experiment_spec_path) + info_prepath = spec['meta']['info_prepath'] + if not os.path.exists(f'{info_prepath}_trial_data_dict.json'): + return # only run analysis if the experiment has been run + trial_data_dict = util.read(f'{info_prepath}_trial_data_dict.json') + analysis.analyze_experiment(spec, trial_data_dict)
+ + +
[docs]def retro_analyze(predir): + ''' + Method to analyze experiment/trial from files after it ran. + @example + + yarn retro_analyze data/reinforce_cartpole_2018_01_22_211751/ + ''' + predir = predir.strip('/') # sanitary + os.environ['LOG_PREPATH'] = f'{predir}/log/retro_analyze' # to prevent overwriting log file + logger.info(f'Running retro-analysis on {predir}') + retro_analyze_sessions(predir) + retro_analyze_trials(predir) + retro_analyze_experiment(predir) + logger.info('Finished retro-analysis')
diff --git a/docs/build/html/_modules/convlab/human_eval/analysis.html b/docs/build/html/_modules/convlab/human_eval/analysis.html new file mode 100644 index 0000000..cf8a310 --- /dev/null +++ b/docs/build/html/_modules/convlab/human_eval/analysis.html @@ -0,0 +1,265 @@ convlab.human_eval.analysis — ConvLab 0.1 documentation

Source code for convlab.human_eval.analysis

+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import argparse
+import json
+import os
+
+import numpy as np
+
+
+
[docs]def main(): + """Analyze human evaluation results: success rates per goal level, dialog and + turn lengths, and understanding/appropriateness scores, read from the result + JSON files in the data path. + """ + parser = argparse.ArgumentParser(description='Analyze MEDkit experiment results') + parser.add_argument( + '-dp', '--datapath', default='./', + help='path to datasets, defaults to current directory') + + args = parser.parse_args() + + dirs = os.listdir(args.datapath) + + num_s_dials = 0 + num_f_dials = 0 + dial_lens = [] + usr_turn_lens = [] + sys_turn_lens = [] + u_scores = [] + a_scores = [] + num_domains = [] + num_s_per_level = {1: 0, 2: 0, 3: 0, 4: 0, 5: 0} + num_f_per_level = {1: 0, 2: 0, 3: 0, 4: 0, 5: 0} + for file in dirs: + if os.path.isfile(os.path.join(args.datapath, file)) and file.endswith('json'): + with open(os.path.join(args.datapath, file)) as f: + print(file) + result = json.load(f) + # the goal level is the number of domains, capped at 5 + level = len(result['goal']['domain_ordering']) + if level > 5: + level = 5 + num_domains.append(level) + if result['success']: + num_s_dials += 1 + num_s_per_level[level] += 1 + else: + num_f_dials += 1 + num_f_per_level[level] += 1 + dial_lens.append(len(result['dialog'])) + for turn in result['dialog']: + if turn[0] == 0: + usr_turn_lens.append(len(turn[1].split())) + elif turn[0] == 1: + sys_turn_lens.append(len(turn[1].split())) + u_scores.append(result['understanding_score']) + a_scores.append(result['appropriateness_score']) + print('Total number of dialogs:', num_s_dials + num_f_dials) + print('Success rate:', num_s_dials/(num_s_dials + num_f_dials)) + for level in num_s_per_level: + s_rate = 0 if num_s_per_level[level] + num_f_per_level[level] == 0 else\ + num_s_per_level[level] / (num_s_per_level[level] + num_f_per_level[level]) + print('Level {} success rate: {}'.format(level, s_rate)) + print('Avg dialog length: {}(+-{})'.format(np.mean(dial_lens), np.std(dial_lens))) + print('Avg turn length: {}(+-{})'.format(np.mean(usr_turn_lens+sys_turn_lens), np.std(usr_turn_lens+sys_turn_lens))) + print('Avg user turn length: {}(+-{})'.format(np.mean(usr_turn_lens), np.std(usr_turn_lens))) + print('Avg system turn length: {}(+-{})'.format(np.mean(sys_turn_lens), np.std(sys_turn_lens))) + print('Avg number of domains: {}(+-{})'.format(np.mean(num_domains), np.std(num_domains))) + print('Avg understanding score: {}(+-{})'.format(np.mean(u_scores), np.std(u_scores))) + print('Avg appropriateness score: {}(+-{})'.format(np.mean(a_scores), np.std(a_scores)))
+ + +if __name__ == '__main__': + main() +
diff --git a/docs/build/html/_modules/convlab/human_eval/sequicity_server.html b/docs/build/html/_modules/convlab/human_eval/sequicity_server.html new file mode 100644 index 0000000..880651a --- /dev/null +++ b/docs/build/html/_modules/convlab/human_eval/sequicity_server.html @@ -0,0 +1,249 @@ convlab.human_eval.sequicity_server — ConvLab 0.1 documentation

Source code for convlab.human_eval.sequicity_server

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from pprint import pprint
+from queue import PriorityQueue
+from threading import Thread
+
+from flask import Flask, request, jsonify
+
+from convlab.modules.e2e.multiwoz.Sequicity.model import main as sequicity_load
+
+rgi_queue = PriorityQueue(maxsize=0)
+rgo_queue = PriorityQueue(maxsize=0)
+
+app = Flask(__name__)
+
+
+
[docs]@app.route('/', methods=['GET', 'POST']) +def process(): + try: + in_request = request.json + print(in_request) + except Exception: + # request.json raised, so in_request was never bound + return "invalid input" + rgi_queue.put(in_request) + rgi_queue.join() + output = rgo_queue.get() + print(output['response']) + rgo_queue.task_done() + return jsonify(output)
+ + +
[docs]def generate_response(in_queue, out_queue): + # Load Sequicity model + sequicity = sequicity_load('load', 'tsdf-multiwoz') + + while True: + # pop input + in_request = in_queue.get() + state = in_request['state'] + user_input = in_request['input'] # avoid shadowing the built-in input() + pprint(in_request) + try: + state = sequicity.predict(user_input, state) + except Exception as e: + print('State update error', e) + state = {} + pprint(state) + try: + response = state['sys'] + except Exception as e: + print('Response generation error', e) + response = 'What did you say?' + out_queue.put({'response': response, 'state': state}) + in_queue.task_done() + out_queue.join()
+ + +if __name__ == '__main__': + worker = Thread(target=generate_response, args=(rgi_queue, rgo_queue,)) + worker.daemon = True # setDaemon() is deprecated + worker.start() + + app.run(host='0.0.0.0', port=10001) +
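A minimal client for this server, a sketch assuming it is running locally on the port above (the payload keys 'state' and 'input' match what generate_response reads; the utterance is made up):

    import requests

    resp = requests.post('http://localhost:10001/',
                         json={'state': {}, 'input': 'I need a cheap hotel in the centre.'})
    print(resp.json()['response'])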
diff --git a/docs/build/html/_modules/convlab/lib/decorator.html b/docs/build/html/_modules/convlab/lib/decorator.html new file mode 100644 index 0000000..cf2aee2 --- /dev/null +++ b/docs/build/html/_modules/convlab/lib/decorator.html @@ -0,0 +1,232 @@ convlab.lib.decorator — ConvLab 0.1 documentation

Source code for convlab.lib.decorator

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import time
+from functools import wraps
+
+from convlab.lib import logger
+
+logger = logger.get_logger(__name__)
+
+
+
[docs]def lab_api(fn): + ''' + Function decorator to label and check Lab API methods + @example + + from convlab.lib.decorator import lab_api + @lab_api + def foo(): + print('foo') + ''' + return fn
+ + +
[docs]def timeit(fn): + ''' + Function decorator to measure execution time + @example + + from convlab.lib.decorator import timeit + @timeit + def foo(sec): + time.sleep(sec) + print('foo') + + foo(1) + # => foo + # => Timed: foo 1000.9971ms + ''' + @wraps(fn) + def time_fn(*args, **kwargs): + start = time.time() + output = fn(*args, **kwargs) + end = time.time() + logger.debug(f'Timed: {fn.__name__} {round((end - start) * 1000, 4)}ms') + return output + return time_fn
+
diff --git a/docs/build/html/_modules/convlab/lib/distribution.html b/docs/build/html/_modules/convlab/lib/distribution.html new file mode 100644 index 0000000..c1d7c52 --- /dev/null +++ b/docs/build/html/_modules/convlab/lib/distribution.html @@ -0,0 +1,275 @@ convlab.lib.distribution — ConvLab 0.1 documentation

Source code for convlab.lib.distribution

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+# Custom PyTorch distribution classes to be registered in policy_util.py
+# Mainly used by policy_util action distribution
+from torch import distributions
+
+
+
[docs]class Argmax(distributions.Categorical): + ''' + Special distribution class for argmax sampling, where probability is always 1 for the argmax. + NOTE although argmax is not a sampling distribution, this implementation is for API consistency. + ''' + + def __init__(self, probs=None, logits=None, validate_args=None): + if probs is not None: + new_probs = torch.zeros_like(probs, dtype=torch.float) + new_probs[probs == probs.max(dim=-1, keepdim=True)[0]] = 1.0 + probs = new_probs + elif logits is not None: + new_logits = torch.full_like(logits, -1e8, dtype=torch.float) + new_logits[logits == logits.max(dim=-1, keepdim=True)[0]] = 1.0 + logits = new_logits + + super().__init__(probs=probs, logits=logits, validate_args=validate_args)
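A quick check of the class above: all probability mass collapses onto the argmax, so sampling is deterministic (the probs are made-up values):

    import torch

    d = Argmax(probs=torch.tensor([0.1, 0.7, 0.2]))
    print(d.probs)     # tensor([0., 1., 0.])
    print(d.sample())  # always tensor(1)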
+ + +
+
+
+class GumbelCategorical(distributions.Categorical):
+    '''
+    Special Categorical using Gumbel distribution to simulate softmax categorical for discrete action.
+    Similar to OpenAI's https://github.com/openai/baselines/blob/98257ef8c9bd23a24a330731ae54ed086d9ce4a7/baselines/a2c/utils.py#L8-L10
+    Explanation http://amid.fish/assets/gumbel.html
+    '''
+
+    def sample(self, sample_shape=torch.Size()):
+        '''Gumbel softmax sampling'''
+        u = torch.empty(self.logits.size(), device=self.logits.device, dtype=self.logits.dtype).uniform_(0, 1)
+        noisy_logits = self.logits - torch.log(-torch.log(u))
+        return torch.argmax(noisy_logits, dim=0)
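The Gumbel-max trick underlying sample(): adding independent Gumbel(0, 1) noise, -log(-log u) with u ~ Uniform(0, 1), to the logits and taking the argmax draws exact samples from the softmax distribution. A sanity-check sketch (editorial):

import torch
from convlab.lib.distribution import GumbelCategorical

dist = GumbelCategorical(logits=torch.tensor([1.0, 0.0]))
samples = torch.stack([dist.sample() for _ in range(10000)])
(samples == 0).float().mean()    # ~= e / (e + 1) ~= 0.73, the softmax probability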
+
+
+class MultiCategorical(distributions.Categorical):
+    '''MultiCategorical as collection of Categoricals'''
+
+    def __init__(self, probs=None, logits=None, validate_args=None):
+        self.categoricals = []
+        if probs is None:
+            probs = [None] * len(logits)
+        elif logits is None:
+            logits = [None] * len(probs)
+        else:
+            raise ValueError('Either probs or logits must be None')
+
+        for sub_probs, sub_logits in zip(probs, logits):
+            categorical = distributions.Categorical(probs=sub_probs, logits=sub_logits, validate_args=validate_args)
+            self.categoricals.append(categorical)
+
+    @property
+    def logits(self):
+        return [cat.logits for cat in self.categoricals]
+
+    @property
+    def probs(self):
+        return [cat.probs for cat in self.categoricals]
+
+    @property
+    def param_shape(self):
+        return [cat.param_shape for cat in self.categoricals]
+
+    @property
+    def mean(self):
+        return torch.stack([cat.mean for cat in self.categoricals])
+
+    @property
+    def variance(self):
+        return torch.stack([cat.variance for cat in self.categoricals])
+
+    def sample(self, sample_shape=torch.Size()):
+        return torch.stack([cat.sample(sample_shape=sample_shape) for cat in self.categoricals])
+
+    def log_prob(self, value):
+        value_t = value.transpose(0, 1)
+        return torch.stack([cat.log_prob(value_t[idx]) for idx, cat in enumerate(self.categoricals)])
+
+    def entropy(self):
+        return torch.stack([cat.entropy() for cat in self.categoricals])
+
+    def enumerate_support(self):
+        return [cat.enumerate_support() for cat in self.categoricals]
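A usage sketch (editorial): each sub-Categorical models one discrete action dimension, and sample() stacks one draw per dimension.

import torch
from convlab.lib.distribution import MultiCategorical

dist = MultiCategorical(logits=[torch.rand(3), torch.rand(4)])
dist.sample()     # shape (2,): one index per action dimension
dist.entropy()    # shape (2,): per-dimension entropies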
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/lib/file_util.html b/docs/build/html/_modules/convlab/lib/file_util.html
new file mode 100644
index 0000000..a8e248f
--- /dev/null
+++ b/docs/build/html/_modules/convlab/lib/file_util.html
@@ -0,0 +1,198 @@
+convlab.lib.file_util — ConvLab 0.1 documentation

Source code for convlab.lib.file_util

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from pathlib import Path
+
+from allennlp.common.file_utils import cached_path as allennlp_cached_path
+
+
+
+def cached_path(file_path, cached_dir=None):
+    if not cached_dir:
+        cached_dir = str(Path(Path.home() / '.convlab') / "cache")
+
+    return allennlp_cached_path(file_path, cached_dir)
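A usage sketch (editorial): an existing local path is resolved as-is, while a URL is downloaded once into ~/.convlab/cache and the local cache path is returned (behavior inherited from allennlp's cached_path). The URL below is hypothetical.

from convlab.lib.file_util import cached_path

local = cached_path('data/multiwoz/da_slot_cnt.json')       # local file, returned directly
remote = cached_path('https://example.com/model.tar.gz')    # hypothetical URL, cached on first call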
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/lib/logger.html b/docs/build/html/_modules/convlab/lib/logger.html
new file mode 100644
index 0000000..c7b7384
--- /dev/null
+++ b/docs/build/html/_modules/convlab/lib/logger.html
@@ -0,0 +1,314 @@
+convlab.lib.logger — ConvLab 0.1 documentation

Source code for convlab.lib.logger

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import logging
+import os
+import sys
+import warnings
+
+import colorlog
+import pandas as pd
+
+
+
+class FixedList(list):
+    '''fixed-list to restrict addition to root logger handler'''
+
+    def append(self, e):
+        pass
+
+
+NEW_LVLS = {'NL': 17, 'ACT': 14, 'STATE': 13}
+for name, val in NEW_LVLS.items():
+    logging.addLevelName(val, name)
+    setattr(logging, name, val)
+
+LOG_FORMAT = '[%(asctime)s PID:%(process)d %(levelname)s %(filename)s %(funcName)s] %(message)s'
+color_formatter = colorlog.ColoredFormatter(
+    '%(log_color)s[%(asctime)s PID:%(process)d %(levelname)s %(filename)s %(funcName)s]%(reset)s %(message)s',
+    log_colors={
+        'DEBUG': 'cyan',
+        'NL': 'cyan',
+        'ACT': 'cyan',
+        'STATE': 'cyan',
+        'INFO': 'green',
+        'WARNING': 'yellow',
+        'ERROR': 'red',
+        'CRITICAL': 'red,bg_white'})
+sh = logging.StreamHandler(sys.stdout)
+sh.setFormatter(color_formatter)
+lab_logger = logging.getLogger()
+lab_logger.handlers = FixedList([sh])
+logging.getLogger('ray').propagate = False  # hack to mute poorly designed ray TF warning log
+
+# this will trigger from Experiment init on reload(logger)
+if os.environ.get('LOG_PREPATH') is not None:
+    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)
+
+    log_filepath = os.environ['LOG_PREPATH'] + '.log'
+    os.makedirs(os.path.dirname(log_filepath), exist_ok=True)
+    # create file handler
+    formatter = logging.Formatter(LOG_FORMAT)
+    fh = logging.FileHandler(log_filepath)
+    fh.setFormatter(formatter)
+    # add stream and file handler
+    lab_logger.handlers = FixedList([sh, fh])
+
+if os.environ.get('LOG_LEVEL'):
+    lab_logger.setLevel(os.environ['LOG_LEVEL'])
+else:
+    lab_logger.setLevel('INFO')
+
+
+def set_level(lvl):
+    lab_logger.setLevel(lvl)
+    os.environ['LOG_LEVEL'] = lvl
+
+
+def critical(msg, *args, **kwargs):
+    return lab_logger.critical(msg, *args, **kwargs)
+
+
+def debug(msg, *args, **kwargs):
+    return lab_logger.debug(msg, *args, **kwargs)
+
+
+def error(msg, *args, **kwargs):
+    return lab_logger.error(msg, *args, **kwargs)
+
+
+def exception(msg, *args, **kwargs):
+    return lab_logger.exception(msg, *args, **kwargs)
+
+
+def info(msg, *args, **kwargs):
+    return lab_logger.info(msg, *args, **kwargs)
+
+
+def warning(msg, *args, **kwargs):
+    return lab_logger.warning(msg, *args, **kwargs)
+
+
+def nl(msg, *args, **kwargs):
+    return lab_logger.log(NEW_LVLS['NL'], msg, *args, **kwargs)
+
+
+def act(msg, *args, **kwargs):
+    return lab_logger.log(NEW_LVLS['ACT'], msg, *args, **kwargs)
+
+
+def state(msg, *args, **kwargs):
+    return lab_logger.log(NEW_LVLS['STATE'], msg, *args, **kwargs)
+
+
+def get_logger(__name__):
+    '''Create a child logger specific to a module'''
+    module_logger = logging.getLogger(__name__)
+
+    def nl(msg, *args, **kwargs):
+        return module_logger.log(NEW_LVLS['NL'], msg, *args, **kwargs)
+
+    def act(msg, *args, **kwargs):
+        return module_logger.log(NEW_LVLS['ACT'], msg, *args, **kwargs)
+
+    def state(msg, *args, **kwargs):
+        return module_logger.log(NEW_LVLS['STATE'], msg, *args, **kwargs)
+
+    setattr(module_logger, 'nl', nl)
+    setattr(module_logger, 'act', act)
+    setattr(module_logger, 'state', state)
+
+    return module_logger
+
+
+def toggle_debug(modules, level='DEBUG'):
+    '''Turn on module-specific debugging using their names, e.g. algorithm, actor_critic, at the desired debug level.'''
+    logger_names = list(logging.Logger.manager.loggerDict.keys())
+    for module in modules:
+        name = module.strip()
+        for logger_name in logger_names:
+            if name in logger_name.split('.'):
+                module_logger = logging.getLogger(logger_name)
+                module_logger.setLevel(getattr(logging, level))
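A usage sketch (editorial): modules obtain a child logger via get_logger(__name__), which also carries the custom nl/act/state level methods attached above.

from convlab.lib import logger

logger = logger.get_logger(__name__)
logger.info('session started')    # standard level
logger.nl('user utterance')       # custom NL level (17)
logger.act('system action')       # custom ACT level (14)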
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/lib/math_util.html b/docs/build/html/_modules/convlab/lib/math_util.html
new file mode 100644
index 0000000..0c23c2b
--- /dev/null
+++ b/docs/build/html/_modules/convlab/lib/math_util.html
@@ -0,0 +1,346 @@
+convlab.lib.math_util — ConvLab 0.1 documentation

Source code for convlab.lib.math_util

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+# Various math calculations used by algorithms
+import numpy as np
+import torch
+
+
+# general math methods
+
+
+def normalize(v):
+    '''Method to normalize a rank-1 np array'''
+    v_min = v.min()
+    v_max = v.max()
+    v_range = v_max - v_min
+    v_range += 1e-08  # division guard
+    v_norm = (v - v_min) / v_range
+    return v_norm
+
+
+def standardize(v):
+    '''Method to standardize a rank-1 np array'''
+    # assert len(v) > 1, 'Cannot standardize vector of size 1'
+    if len(v) == 1:
+        return v
+
+    v_std = (v - v.mean()) / (v.std() + 1e-08)
+    return v_std
+
+
+def to_one_hot(data, max_val):
+    '''Convert an int list of data into one-hot vectors'''
+    return np.eye(max_val)[np.array(data)]
+
+
+def venv_pack(batch_tensor, num_envs):
+    '''Apply the reverse of venv_unpack to pack a batch tensor from (b*num_envs, *shape) to (b, num_envs, *shape)'''
+    shape = list(batch_tensor.shape)
+    if len(shape) < 2:  # scalar data (b, num_envs,)
+        return batch_tensor.view(-1, num_envs)
+    else:  # non-scalar data (b, num_envs, *shape)
+        pack_shape = [-1, num_envs] + shape[1:]
+        return batch_tensor.view(pack_shape)
+
+
+def venv_unpack(batch_tensor):
+    '''
+    Unpack a sampled vec env batch tensor
+    e.g. for a state with original shape (4, ), vec env should return vec state with shape (num_envs, 4) to store in memory
+    When sampled with batch_size b, we should get shape (b, num_envs, 4). But we need to unpack the num_envs dimension to get (b * num_envs, 4) for passing to a network. This method does that.
+    '''
+    shape = list(batch_tensor.shape)
+    if len(shape) < 3:  # scalar data (b, num_envs,)
+        return batch_tensor.view(-1)
+    else:  # non-scalar data (b, num_envs, *shape)
+        unpack_shape = [-1] + shape[2:]
+        return batch_tensor.view(unpack_shape)
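A shape sketch (editorial): venv_unpack flattens the (batch, num_envs) leading dimensions for a network forward pass, and venv_pack restores them.

import torch
from convlab.lib import math_util

batch = torch.rand(8, 4, 84)                        # (b, num_envs, *shape)
flat = math_util.venv_unpack(batch)                 # => shape (32, 84)
restored = math_util.venv_pack(flat, num_envs=4)    # => shape (8, 4, 84)
assert torch.equal(batch, restored)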
+
+
+# Policy Gradient calc
+# advantage functions
+
+def calc_returns(rewards, dones, gamma):
+    '''
+    Calculate the simple returns (full rollout) i.e. sum discounted rewards up till termination
+    '''
+    T = len(rewards)
+    rets = torch.zeros_like(rewards)
+    future_ret = torch.tensor(0.0, dtype=rewards.dtype)
+    not_dones = 1 - dones
+    for t in reversed(range(T)):
+        rets[t] = future_ret = rewards[t] + gamma * future_ret * not_dones[t]
+    return rets
+
+
+def calc_nstep_returns(rewards, dones, next_v_pred, gamma, n):
+    '''
+    Calculate the n-step returns for advantage. Ref: http://www-anw.cs.umass.edu/~barto/courses/cs687/Chapter%207.pdf
+    Also see Algorithm S3 from A3C paper https://arxiv.org/pdf/1602.01783.pdf for the calculation used below
+    R^(n)_t = r_{t} + gamma r_{t+1} + ... + gamma^(n-1) r_{t+n-1} + gamma^(n) V(s_{t+n})
+    '''
+    rets = torch.zeros_like(rewards)
+    future_ret = next_v_pred
+    not_dones = 1 - dones
+    for t in reversed(range(n)):
+        rets[t] = future_ret = rewards[t] + gamma * future_ret * not_dones[t]
+    return rets
+
+
+def calc_gaes(rewards, dones, v_preds, gamma, lam):
+    '''
+    Calculate GAE from Schulman et al. https://arxiv.org/pdf/1506.02438.pdf
+    v_preds are values predicted for current states, with one last element as the final next_state
+    delta is defined as r + gamma * V(s') - V(s) in eqn 10
+    GAE is defined in eqn 16
+    This method computes in torch tensor to prevent unnecessary moves between devices (e.g. GPU tensor to CPU numpy)
+    NOTE any standardization is done outside of this method
+    '''
+    T = len(rewards)
+    assert T + 1 == len(v_preds)  # v_preds includes states and 1 last next_state
+    gaes = torch.zeros_like(rewards)
+    future_gae = torch.tensor(0.0, dtype=rewards.dtype)
+    # to multiply with not_dones to handle episode boundary (last state has no V(s'))
+    not_dones = 1 - dones
+    for t in reversed(range(T)):
+        delta = rewards[t] + gamma * v_preds[t + 1] * not_dones[t] - v_preds[t]
+        gaes[t] = future_gae = delta + gamma * lam * not_dones[t] * future_gae
+    return gaes
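A small worked example (editorial): with gamma = 0.9 and a 3-step episode that terminates, calc_returns computes R_t = r_t + gamma * R_{t+1}, zeroing the bootstrap across the episode boundary.

import torch
from convlab.lib import math_util

rewards = torch.tensor([1.0, 1.0, 1.0])
dones = torch.tensor([0.0, 0.0, 1.0])
math_util.calc_returns(rewards, dones, gamma=0.9)
# => tensor([2.7100, 1.9000, 1.0000]): R2 = 1, R1 = 1 + 0.9 * 1, R0 = 1 + 0.9 * 1.9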
+
+
+def calc_q_value_logits(state_value, raw_advantages):
+    mean_adv = raw_advantages.mean(dim=-1).unsqueeze(dim=-1)
+    return state_value + raw_advantages - mean_adv
+
+
+# generic variable decay methods
+
+def no_decay(start_val, end_val, start_step, end_step, step):
+    '''dummy method for API consistency'''
+    return start_val
+
+
+def linear_decay(start_val, end_val, start_step, end_step, step):
+    '''Simple linear decay with annealing'''
+    if step < start_step:
+        return start_val
+    slope = (end_val - start_val) / (end_step - start_step)
+    val = max(slope * (step - start_step) + start_val, end_val)
+    return val
+
+
+def rate_decay(start_val, end_val, start_step, end_step, step, decay_rate=0.9, frequency=20.):
+    '''Compounding rate decay that anneals in 20 decay iterations until end_step'''
+    if step < start_step:
+        return start_val
+    if step >= end_step:
+        return end_val
+    step_per_decay = (end_step - start_step) / frequency
+    decay_step = (step - start_step) / step_per_decay
+    val = max(np.power(decay_rate, decay_step) * start_val, end_val)
+    return val
+
+
+def periodic_decay(start_val, end_val, start_step, end_step, step, frequency=60.):
+    '''
+    Linearly decaying sinusoid that decays in roughly 10 iterations until explore_anneal_epi
+    Plot the equation below to see the pattern
+    suppose sinusoidal decay, start_val = 1, end_val = 0.2, stop after 60 unscaled x steps
+    then we get 0.2+0.5*(1-0.2)(1 + cos x)*(1-x/60)
+    '''
+    if step < start_step:
+        return start_val
+    if step >= end_step:
+        return end_val
+    x_freq = frequency
+    step_per_decay = (end_step - start_step) / x_freq
+    x = (step - start_step) / step_per_decay
+    unit = start_val - end_val
+    val = end_val * 0.5 * unit * (1 + np.cos(x) * (1 - x / x_freq))
+    val = max(val, end_val)
+    return val
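A usage sketch (editorial): all schedulers share one signature, so an exploration variable can switch schedules freely.

from convlab.lib import math_util

# anneal epsilon from 1.0 to 0.1 between steps 0 and 1000
math_util.linear_decay(1.0, 0.1, 0, 1000, step=0)       # => 1.0
math_util.linear_decay(1.0, 0.1, 0, 1000, step=500)     # => 0.55
math_util.linear_decay(1.0, 0.1, 0, 1000, step=2000)    # => 0.1 (clamped at end_val)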
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/lib/optimizer.html b/docs/build/html/_modules/convlab/lib/optimizer.html
new file mode 100644
index 0000000..72434cb
--- /dev/null
+++ b/docs/build/html/_modules/convlab/lib/optimizer.html
@@ -0,0 +1,291 @@
+convlab.lib.optimizer — ConvLab 0.1 documentation

Source code for convlab.lib.optimizer

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+# Custom PyTorch optimizer classes, to be registered in net_util.py
+import math
+
+import torch
+
+
+
+class GlobalAdam(torch.optim.Adam):
+    '''
+    Global Adam algorithm with shared states for Hogwild.
+    Adapted from https://github.com/ikostrikov/pytorch-a3c/blob/master/my_optim.py (MIT)
+    '''
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+        super().__init__(params, lr, betas, eps, weight_decay)
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['step'] = torch.zeros(1)
+                state['exp_avg'] = p.data.new().resize_as_(p.data).zero_()
+                state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_()
+
+    def share_memory(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['step'].share_memory_()
+                state['exp_avg'].share_memory_()
+                state['exp_avg_sq'].share_memory_()
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                state = self.state[p]
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+                state['step'] += 1
+                if group['weight_decay'] != 0:
+                    grad = grad.add(group['weight_decay'], p.data)
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                denom = exp_avg_sq.sqrt().add_(group['eps'])
+                bias_correction1 = 1 - beta1 ** state['step'].item()
+                bias_correction2 = 1 - beta2 ** state['step'].item()
+                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+                p.data.addcdiv_(-step_size, exp_avg, denom)
+        return loss
+
+
+class GlobalRMSprop(torch.optim.RMSprop):
+    '''
+    Global RMSprop algorithm with shared states for Hogwild.
+    Adapted from https://github.com/jingweiz/pytorch-rl/blob/master/optims/sharedRMSprop.py (MIT)
+    '''
+
+    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0):
+        super().__init__(params, lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay, momentum=0, centered=False)
+        # State initialisation (must be done before step, else will not be shared between threads)
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['step'] = p.data.new().resize_(1).zero_()
+                state['square_avg'] = p.data.new().resize_as_(p.data).zero_()
+
+    def share_memory(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['step'].share_memory_()
+                state['square_avg'].share_memory_()
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                state = self.state[p]
+                square_avg = state['square_avg']
+                alpha = group['alpha']
+                state['step'] += 1
+                if group['weight_decay'] != 0:
+                    grad = grad.add(group['weight_decay'], p.data)
+                square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)
+                avg = square_avg.sqrt().add_(group['eps'])
+                p.data.addcdiv_(-group['lr'], grad, avg)
+        return loss
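A usage sketch (editorial): for Hogwild training, both the model and the optimizer state must live in shared memory before worker processes are spawned.

import torch.nn as nn
from convlab.lib.optimizer import GlobalAdam

net = nn.Linear(4, 2)
net.share_memory()                    # share model weights across processes
optim = GlobalAdam(net.parameters(), lr=1e-3)
optim.share_memory()                  # share step/exp_avg/exp_avg_sq buffers
# each spawned worker computes gradients locally, then calls optim.step()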
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/lib/util.html b/docs/build/html/_modules/convlab/lib/util.html
new file mode 100644
index 0000000..895ef35
--- /dev/null
+++ b/docs/build/html/_modules/convlab/lib/util.html
@@ -0,0 +1,935 @@
+convlab.lib.util — ConvLab 0.1 documentation

Source code for convlab.lib.util

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import json
+import operator
+import os
+import pickle
+import subprocess
+import sys
+import time
+from collections import deque
+from contextlib import contextmanager
+from datetime import datetime
+from importlib import reload
+from pprint import pformat
+
+import cv2
+import numpy as np
+import pandas as pd
+import pydash as ps
+import regex as re
+import torch
+import torch.multiprocessing as mp
+import ujson
+import yaml
+
+from convlab import ROOT_DIR, EVAL_MODES
+
+NUM_CPUS = mp.cpu_count()
+FILE_TS_FORMAT = '%Y_%m_%d_%H%M%S'
+RE_FILE_TS = re.compile(r'(\d{4}_\d{2}_\d{2}_\d{6})')
+
+
+
+class LabJsonEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, (np.ndarray, pd.Series)):
+            return obj.tolist()
+        else:
+            return str(obj)
+
+
+def batch_get(arr, idxs):
+    '''Get multi-idxs from an array depending if it's a python list or np.array'''
+    if isinstance(arr, (list, deque)):
+        return np.array(operator.itemgetter(*idxs)(arr))
+    else:
+        return arr[idxs]
+
+
+def calc_srs_mean_std(sr_list):
+    '''Given a list of series, calculate their mean and std'''
+    cat_df = pd.DataFrame(dict(enumerate(sr_list)))
+    mean_sr = cat_df.mean(axis=1)
+    std_sr = cat_df.std(axis=1)
+    return mean_sr, std_sr
+
+
+def calc_ts_diff(ts2, ts1):
+    '''
+    Calculate the time from tss ts1 to ts2
+    @param {str} ts2 Later ts in the FILE_TS_FORMAT
+    @param {str} ts1 Earlier ts in the FILE_TS_FORMAT
+    @returns {str} delta_t in %H:%M:%S format
+    @example
+
+    ts1 = '2017_10_17_084739'
+    ts2 = '2017_10_17_084740'
+    ts_diff = util.calc_ts_diff(ts2, ts1)
+    # => '0:00:01'
+    '''
+    delta_t = datetime.strptime(ts2, FILE_TS_FORMAT) - datetime.strptime(ts1, FILE_TS_FORMAT)
+    return str(delta_t)
+
+
+def cast_df(val):
+    '''missing pydash method to cast value as DataFrame'''
+    if isinstance(val, pd.DataFrame):
+        return val
+    return pd.DataFrame(val)
+
+
+def cast_list(val):
+    '''missing pydash method to cast value as list'''
+    if ps.is_list(val):
+        return val
+    else:
+        return [val]
+
+
+def clear_periodic_ckpt(prepath):
+    '''Clear periodic (with -epi) ckpt files in prepath'''
+    if '-epi' in prepath:
+        run_cmd(f'rm {prepath}*')
+
+
+def concat_batches(batches):
+    '''
+    Concat batch objects from body.memory.sample() into one batch, when all bodies experience similar envs
+    Also concat any nested epi sub-batches into flat batch
+    {k: arr1} + {k: arr2} = {k: arr1 + arr2}
+    '''
+    # if is nested, then is episodic
+    is_episodic = isinstance(batches[0]['dones'][0], (list, np.ndarray))
+    concat_batch = {}
+    for k in batches[0]:
+        datas = []
+        for batch in batches:
+            data = batch[k]
+            if is_episodic:  # make into plain batch instead of nested
+                data = np.concatenate(data)
+            datas.append(data)
+        concat_batch[k] = np.concatenate(datas)
+    return concat_batch
+
+
+def downcast_float32(df):
+    '''Downcast any float64 col to float32 to allow safer pandas comparison'''
+    for col in df.columns:
+        if df[col].dtype == 'float':
+            df[col] = df[col].astype('float32')
+    return df
+
+
+def epi_done(done):
+    '''
+    General method to check if episode is done for both single and vectorized env
+    Only return True for singleton done since vectorized env does not have a natural episode boundary
+    '''
+    return np.isscalar(done) and done
+
+
+def find_ckpt(prepath):
+    '''Find the ckpt-lorem-ipsum in a string and return lorem-ipsum'''
+    if 'ckpt' in prepath:
+        ckpt_str = ps.find(prepath.split('_'), lambda s: s.startswith('ckpt'))
+        ckpt = ckpt_str.replace('ckpt-', '')
+    else:
+        ckpt = None
+    return ckpt
+
+
+def frame_mod(frame, frequency, num_envs):
+    '''
+    Generic mod for (frame % frequency == 0) for when num_envs is 1 or more,
+    since frame will increase multiple ticks for vector env, use the remainder
+    '''
+    remainder = num_envs or 1
+    return (frame % frequency < remainder)
+
+
+def flatten_dict(obj, delim='.'):
+    '''Missing pydash method to flatten dict'''
+    nobj = {}
+    for key, val in obj.items():
+        if ps.is_dict(val) and not ps.is_empty(val):
+            strip = flatten_dict(val, delim)
+            for k, v in strip.items():
+                nobj[key + delim + k] = v
+        elif ps.is_list(val) and not ps.is_empty(val) and ps.is_dict(val[0]):
+            for idx, v in enumerate(val):
+                nobj[key + delim + str(idx)] = v
+                if ps.is_object(v):
+                    nobj = flatten_dict(nobj, delim)
+        else:
+            nobj[key] = val
+    return nobj
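A small example (editorial) of the flattening convention: nested keys are joined with the delimiter, and lists of dicts are indexed then flattened recursively.

from convlab.lib import util

util.flatten_dict({'agent': {'net': {'lr': 0.01}}, 'env': [{'name': 'movie'}]})
# => {'agent.net.lr': 0.01, 'env.0.name': 'movie'}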
+
+
+def get_class_name(obj, lower=False):
+    '''Get the class name of an object'''
+    class_name = obj.__class__.__name__
+    if lower:
+        class_name = class_name.lower()
+    return class_name
+
+
+def get_class_attr(obj):
+    '''Get the class attr of an object as dict'''
+    attr_dict = {}
+    for k, v in obj.__dict__.items():
+        if hasattr(v, '__dict__') or ps.is_tuple(v):
+            val = str(v)
+        else:
+            val = v
+        attr_dict[k] = val
+    return attr_dict
+
+
+def get_file_ext(data_path):
+    '''get the `.ext` of file.ext'''
+    return os.path.splitext(data_path)[-1]
+
+
+def get_fn_list(a_cls):
+    '''
+    Get the callable, non-private functions of a class
+    @returns {[*str]} A list of strings of fn names
+    '''
+    fn_list = ps.filter_(dir(a_cls), lambda fn: not fn.endswith('__') and callable(getattr(a_cls, fn)))
+    return fn_list
+
+
+def get_git_sha():
+    return subprocess.check_output(['git', 'rev-parse', 'HEAD'], close_fds=True, cwd=ROOT_DIR).decode().strip()
+
+
+def get_lab_mode():
+    return os.environ.get('lab_mode')
+
+
+def get_prepath(spec, unit='experiment'):
+    spec_name = spec['name']
+    meta_spec = spec['meta']
+    predir = f'output/{spec_name}_{meta_spec["experiment_ts"]}'
+    prename = f'{spec_name}'
+    trial_index = meta_spec['trial']
+    session_index = meta_spec['session']
+    t_str = '' if trial_index is None else f'_t{trial_index}'
+    s_str = '' if session_index is None else f'_s{session_index}'
+    if unit == 'trial':
+        prename += t_str
+    elif unit == 'session':
+        prename += f'{t_str}{s_str}'
+    ckpt = meta_spec['ckpt']
+    if ckpt is not None:
+        prename += f'_ckpt-{ckpt}'
+    prepath = f'{predir}/{prename}'
+    return prepath
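A worked example (editorial) of the output path convention, using a hypothetical spec:

from convlab.lib import util

spec = {'name': 'dqn_movie', 'meta': {
    'experiment_ts': '2019_01_01_120000', 'trial': 0, 'session': 1, 'ckpt': None}}
util.get_prepath(spec, unit='session')
# => 'output/dqn_movie_2019_01_01_120000/dqn_movie_t0_s1'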
+
+
+def get_ts(pattern=FILE_TS_FORMAT):
+    '''
+    Get current ts, defaults to format used for filename
+    @param {str} pattern To format the ts
+    @returns {str} ts
+    @example
+
+    util.get_ts()
+    # => '2017_10_17_084739'
+    '''
+    ts_obj = datetime.now()
+    ts = ts_obj.strftime(pattern)
+    assert RE_FILE_TS.search(ts)
+    return ts
+
+
+def insert_folder(prepath, folder):
+    '''Insert a folder into prepath'''
+    split_path = prepath.split('/')
+    prename = split_path.pop()
+    split_path += [folder, prename]
+    return '/'.join(split_path)
+
+
+def in_eval_lab_modes():
+    '''Check if lab_mode is one of EVAL_MODES'''
+    return get_lab_mode() in EVAL_MODES
+
+
+def is_jupyter():
+    '''Check if process is in Jupyter kernel'''
+    try:
+        get_ipython().config
+        return True
+    except NameError:
+        return False
+    return False
+
+
+@contextmanager
+def ctx_lab_mode(lab_mode):
+    '''
+    Creates context to run method with a specific lab_mode
+    @example
+    with util.ctx_lab_mode('eval'):
+        foo()
+
+    @util.ctx_lab_mode('eval')
+    def foo():
+        ...
+    '''
+    prev_lab_mode = os.environ.get('lab_mode')
+    os.environ['lab_mode'] = lab_mode
+    yield
+    if prev_lab_mode is None:
+        del os.environ['lab_mode']
+    else:
+        os.environ['lab_mode'] = prev_lab_mode
+
+
+def monkey_patch(base_cls, extend_cls):
+    '''Monkey patch a base class with methods from extend_cls'''
+    ext_fn_list = get_fn_list(extend_cls)
+    for fn in ext_fn_list:
+        setattr(base_cls, fn, getattr(extend_cls, fn))
+
+
+def parallelize(fn, args, num_cpus=NUM_CPUS):
+    '''
+    Parallelize a method fn, args and return results with order preserved per args.
+    args should be a list of tuples.
+    @returns {list} results Order preserved output from fn.
+    '''
+    pool = mp.Pool(num_cpus, maxtasksperchild=1)
+    results = pool.starmap(fn, args)
+    pool.close()
+    pool.join()
+    return results
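A usage sketch (editorial): args is a list of argument tuples, one per call, and results preserve that order. The fn must be picklable (defined at module level).

from convlab.lib import util

def add(a, b):
    return a + b

util.parallelize(add, [(1, 2), (3, 4), (5, 6)], num_cpus=2)
# => [3, 7, 11]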
+
+
+def prepath_split(prepath):
+    '''
+    Split prepath into useful names. Works with predir (prename will be None)
+    prepath: output/dqn_pong_2018_12_02_082510/dqn_pong_t0_s0
+    predir: output/dqn_pong_2018_12_02_082510
+    prefolder: dqn_pong_2018_12_02_082510
+    prename: dqn_pong_t0_s0
+    spec_name: dqn_pong
+    experiment_ts: 2018_12_02_082510
+    ckpt: ckpt-best of dqn_pong_t0_s0_ckpt-best if available
+    '''
+    prepath = prepath.strip('_')
+    tail = prepath.split('output/')[-1]
+    ckpt = find_ckpt(tail)
+    if ckpt is not None:  # separate ckpt
+        tail = tail.replace(f'_ckpt-{ckpt}', '')
+    if '/' in tail:  # tail = prefolder/prename
+        prefolder, prename = tail.split('/', 1)
+    else:
+        prefolder, prename = tail, None
+    predir = f'output/{prefolder}'
+    spec_name = RE_FILE_TS.sub('', prefolder).strip('_')
+    experiment_ts = RE_FILE_TS.findall(prefolder)[0]
+    return predir, prefolder, prename, spec_name, experiment_ts, ckpt
+
+
+def prepath_to_idxs(prepath):
+    '''Extract trial index and session index from prepath if available'''
+    _, _, prename, spec_name, _, _ = prepath_split(prepath)
+    idxs_tail = prename.replace(spec_name, '').strip('_')
+    idxs_strs = ps.compact(idxs_tail.split('_')[:2])
+    if ps.is_empty(idxs_strs):
+        return None, None
+    tidx = idxs_strs[0]
+    assert tidx.startswith('t')
+    trial_index = int(tidx.strip('t'))
+    if len(idxs_strs) == 1:  # no session index present
+        session_index = None
+    else:
+        sidx = idxs_strs[1]
+        assert sidx.startswith('s')
+        session_index = int(sidx.strip('s'))
+    return trial_index, session_index
+
+
+def prepath_to_spec(prepath):
+    '''
+    Given a prepath, read the correct spec and recover the meta_spec that will return the same prepath for eval lab modes
+    example: output/a2c_cartpole_2018_06_13_220436/a2c_cartpole_t0_s0
+    '''
+    predir, _, prename, _, experiment_ts, ckpt = prepath_split(prepath)
+    sidx_res = re.search(r'_s\d+', prename)
+    if sidx_res:  # replace the _s0 if any
+        prename = prename.replace(sidx_res[0], '')
+    spec_path = f'{predir}/{prename}_spec.json'
+    # read the spec of prepath
+    spec = read(spec_path)
+    # recover meta_spec
+    trial_index, session_index = prepath_to_idxs(prepath)
+    meta_spec = spec['meta']
+    meta_spec['experiment_ts'] = experiment_ts
+    meta_spec['ckpt'] = ckpt
+    meta_spec['experiment'] = 0
+    meta_spec['trial'] = trial_index
+    meta_spec['session'] = session_index
+    check_prepath = get_prepath(spec, unit='session')
+    assert check_prepath in prepath, f'{check_prepath}, {prepath}'
+    return spec
+
+
+def read(data_path, **kwargs):
+    '''
+    Universal data reading method with smart data parsing
+    - {.csv} to DataFrame
+    - {.json} to dict, list
+    - {.yml} to dict
+    - {*} to str
+    @param {str} data_path The data path to read from
+    @returns {data} The read data in sensible format
+    @example
+
+    data_df = util.read('test/fixture/lib/util/test_df.csv')
+    # => <DataFrame>
+
+    data_dict = util.read('test/fixture/lib/util/test_dict.json')
+    data_dict = util.read('test/fixture/lib/util/test_dict.yml')
+    # => <dict>
+
+    data_list = util.read('test/fixture/lib/util/test_list.json')
+    # => <list>
+
+    data_str = util.read('test/fixture/lib/util/test_str.txt')
+    # => <str>
+    '''
+    data_path = smart_path(data_path)
+    try:
+        assert os.path.isfile(data_path)
+    except AssertionError:
+        raise FileNotFoundError(data_path)
+    ext = get_file_ext(data_path)
+    if ext == '.csv':
+        data = read_as_df(data_path, **kwargs)
+    elif ext == '.pkl':
+        data = read_as_pickle(data_path, **kwargs)
+    else:
+        data = read_as_plain(data_path, **kwargs)
+    return data
+
+
+def read_as_df(data_path, **kwargs):
+    '''Submethod to read data as DataFrame'''
+    data = pd.read_csv(data_path, **kwargs)
+    return data
+
+
+def read_as_pickle(data_path, **kwargs):
+    '''Submethod to read data as pickle'''
+    with open(data_path, 'rb') as f:
+        data = pickle.load(f)
+    return data
+
+
+def read_as_plain(data_path, **kwargs):
+    '''Submethod to read data as plain type'''
+    open_file = open(data_path, 'r')
+    ext = get_file_ext(data_path)
+    if ext == '.json':
+        data = ujson.load(open_file, **kwargs)
+    elif ext == '.yml':
+        data = yaml.load(open_file, **kwargs)
+    else:
+        data = open_file.read()
+    open_file.close()
+    return data
+
+
+def run_cmd(cmd):
+    '''Run shell command'''
+    print(f'+ {cmd}')
+    proc = subprocess.Popen(cmd, cwd=ROOT_DIR, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
+    return proc
+
+
+def run_cmd_wait(proc):
+    '''Wait on a running process created by util.run_cmd and print its stdout'''
+    for line in proc.stdout:
+        print(line.decode(), end='')
+    output = proc.communicate()[0]
+    if proc.returncode != 0:
+        # note: CalledProcessError takes (returncode, cmd, output); the original swapped the first two args
+        raise subprocess.CalledProcessError(proc.returncode, proc.args, output)
+    else:
+        return output
+
+
+def self_desc(cls):
+    '''Method to get self description, used at init.'''
+    desc_list = [f'{get_class_name(cls)}:']
+    for k, v in get_class_attr(cls).items():
+        if k == 'spec':
+            desc_v = v['name']
+        elif ps.is_dict(v) or ps.is_dict(ps.head(v)):
+            desc_v = pformat(v)
+        else:
+            desc_v = v
+        desc_list.append(f'- {k} = {desc_v}')
+    desc = '\n'.join(desc_list)
+    return desc
+
+
+def set_attr(obj, attr_dict, keys=None):
+    '''Set attribute of an object from a dict'''
+    if keys is not None:
+        attr_dict = ps.pick(attr_dict, keys)
+    for attr, val in attr_dict.items():
+        setattr(obj, attr, val)
+    return obj
+
+
+def set_cuda_id(spec):
+    '''Use trial and session id to hash and modulo cuda device count for a cuda_id to maximize device usage. Sets the net_spec for the base Net class to pick up.'''
+    # Don't trigger any cuda call if not using GPU. Otherwise will break multiprocessing on machines with CUDA.
+    # see issues https://github.com/pytorch/pytorch/issues/334 https://github.com/pytorch/pytorch/issues/3491 https://github.com/pytorch/pytorch/issues/9996
+    for agent_spec in spec['agent']:
+        if 'net' not in agent_spec or not agent_spec['net'].get('gpu'):
+            return
+    meta_spec = spec['meta']
+    trial_idx = meta_spec['trial'] or 0
+    session_idx = meta_spec['session'] or 0
+    if meta_spec['distributed'] == 'shared':  # shared hogwild uses only global networks, offset them to idx 0
+        session_idx = 0
+    job_idx = trial_idx * meta_spec['max_session'] + session_idx
+    job_idx += meta_spec['cuda_offset']
+    device_count = torch.cuda.device_count()
+    cuda_id = None if not device_count else job_idx % device_count
+
+    for agent_spec in spec['agent']:
+        agent_spec['net']['cuda_id'] = cuda_id
+
+
+def set_logger(spec, logger, unit=None):
+    '''Set the logger for a lab unit given its spec'''
+    os.environ['LOG_PREPATH'] = insert_folder(get_prepath(spec, unit=unit), 'log')
+    reload(logger)  # to set session-specific logger
+
+
+def set_random_seed(spec):
+    '''Generate and set random seed for relevant modules, and record it in spec.meta.random_seed'''
+    torch.set_num_threads(1)  # prevent multithread slowdown, set again for hogwild
+    trial = spec['meta']['trial']
+    session = spec['meta']['session']
+    random_seed = int(1e5 * (trial or 0) + 1e3 * (session or 0) + time.time())
+    torch.cuda.manual_seed_all(random_seed)
+    torch.manual_seed(random_seed)
+    np.random.seed(random_seed)
+    spec['meta']['random_seed'] = random_seed
+    return random_seed
+
+
+def _sizeof(obj, seen=None):
+    '''Recursively finds size of objects'''
+    size = sys.getsizeof(obj)
+    if seen is None:
+        seen = set()
+    obj_id = id(obj)
+    if obj_id in seen:
+        return 0
+    # Important: mark as seen *before* entering recursion to gracefully handle
+    # self-referential objects
+    seen.add(obj_id)
+    if isinstance(obj, dict):
+        size += sum([_sizeof(v, seen) for v in obj.values()])
+        size += sum([_sizeof(k, seen) for k in obj.keys()])
+    elif hasattr(obj, '__dict__'):
+        size += _sizeof(obj.__dict__, seen)
+    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
+        size += sum([_sizeof(i, seen) for i in obj])
+    return size
+
+
+def sizeof(obj, divisor=1e6):
+    '''Return the size of object, in MB by default'''
+    return _sizeof(obj) / divisor
+
+
+def smart_path(data_path, as_dir=False):
+    '''
+    Resolve data_path into abspath with fallback to join from ROOT_DIR
+    @param {str} data_path The input data path to resolve
+    @param {bool} as_dir Whether to return as dirname
+    @returns {str} The normalized absolute data_path
+    @example
+
+    util.smart_path('convlab/lib')
+    # => '/Users/ANON/Documents/convlab/convlab/lib'
+
+    util.smart_path('/tmp')
+    # => '/tmp'
+    '''
+    if not os.path.isabs(data_path):
+        abs_path = os.path.abspath(data_path)
+        if os.path.exists(abs_path):
+            data_path = abs_path
+        else:
+            data_path = os.path.join(ROOT_DIR, data_path)
+    if as_dir:
+        data_path = os.path.dirname(data_path)
+    return os.path.normpath(data_path)
+
+
+def split_minibatch(batch, mb_size):
+    '''Split a batch into minibatches of mb_size or smaller, without replacement'''
+    size = len(batch['rewards'])
+    assert mb_size < size, f'Minibatch size {mb_size} must be < batch size {size}'
+    idxs = np.arange(size)
+    np.random.shuffle(idxs)
+    chunks = int(size / mb_size)
+    nested_idxs = np.array_split(idxs, chunks)
+    mini_batches = []
+    for minibatch_idxs in nested_idxs:
+        minibatch = {k: v[minibatch_idxs] for k, v in batch.items()}
+        mini_batches.append(minibatch)
+    return mini_batches
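A quick sketch (editorial): shuffled indices are chunked with np.array_split, so each sample appears in exactly one minibatch and chunk sizes are balanced.

import numpy as np
from convlab.lib import util

batch = {'states': np.arange(10), 'rewards': np.arange(10)}
mbs = util.split_minibatch(batch, mb_size=4)
len(mbs)    # => 2 minibatches of 5 each: int(10 / 4) = 2 balanced chunks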
+
+
+def to_json(d, indent=2):
+    '''Shorthand method for stringify JSON with indent'''
+    return json.dumps(d, indent=indent, cls=LabJsonEncoder)
+
+
+def to_render():
+    return get_lab_mode() in ('dev', 'enjoy') and os.environ.get('RENDER', 'true') == 'true'
+
+
+def to_torch_batch(batch, device, is_episodic):
+    '''Mutate a batch (dict) to make its values from numpy into PyTorch tensor'''
+    for k in batch:
+        if is_episodic:  # for episodic format
+            batch[k] = np.concatenate(batch[k])
+        elif ps.is_list(batch[k]):
+            batch[k] = np.array(batch[k])
+        batch[k] = torch.from_numpy(batch[k].astype(np.float32)).to(device)
+    return batch
+
+
+def write(data, data_path):
+    '''
+    Universal data writing method with smart data parsing
+    - {.csv} from DataFrame
+    - {.json} from dict, list
+    - {.yml} from dict
+    - {*} from str(*)
+    @param {*} data The data to write
+    @param {str} data_path The data path to write to
+    @returns {data_path} The data path written to
+    @example
+
+    data_path = util.write(data_df, 'test/fixture/lib/util/test_df.csv')
+
+    data_path = util.write(data_dict, 'test/fixture/lib/util/test_dict.json')
+    data_path = util.write(data_dict, 'test/fixture/lib/util/test_dict.yml')
+
+    data_path = util.write(data_list, 'test/fixture/lib/util/test_list.json')
+
+    data_path = util.write(data_str, 'test/fixture/lib/util/test_str.txt')
+    '''
+    data_path = smart_path(data_path)
+    data_dir = os.path.dirname(data_path)
+    os.makedirs(data_dir, exist_ok=True)
+    ext = get_file_ext(data_path)
+    if ext == '.csv':
+        write_as_df(data, data_path)
+    elif ext == '.pkl':
+        write_as_pickle(data, data_path)
+    else:
+        write_as_plain(data, data_path)
+    return data_path
+
+
+def write_as_df(data, data_path):
+    '''Submethod to write data as DataFrame'''
+    df = cast_df(data)
+    df.to_csv(data_path, index=False)
+    return data_path
+
+
+def write_as_pickle(data, data_path):
+    '''Submethod to write data as pickle'''
+    with open(data_path, 'wb') as f:
+        pickle.dump(data, f)
+    return data_path
+
+
+def write_as_plain(data, data_path):
+    '''Submethod to write data as plain type'''
+    open_file = open(data_path, 'w')
+    ext = get_file_ext(data_path)
+    if ext == '.json':
+        json.dump(data, open_file, indent=2, cls=LabJsonEncoder)
+    elif ext == '.yml':
+        yaml.dump(data, open_file)
+    else:
+        open_file.write(str(data))
+    open_file.close()
+    return data_path
+
+
+# Atari image preprocessing
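A roundtrip sketch (editorial): the file extension picks the serializer on write and the parser on read; the path below is hypothetical.

from convlab.lib import util

spec = {'name': 'dqn_movie', 'meta': {'max_session': 4}}
path = util.write(spec, 'output/tmp/demo_spec.json')    # serialized via LabJsonEncoder
util.read(path)                                         # => the same dict, parsed by ujson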
+
+
+def to_opencv_image(im):
+    '''Convert to OpenCV image shape h,w,c'''
+    shape = im.shape
+    if len(shape) == 3 and shape[0] < shape[-1]:
+        return im.transpose(1, 2, 0)
+    else:
+        return im
+
+
+def to_pytorch_image(im):
+    '''Convert to PyTorch image shape c,h,w'''
+    shape = im.shape
+    if len(shape) == 3 and shape[-1] < shape[0]:
+        return im.transpose(2, 0, 1)
+    else:
+        return im
+
+
+def grayscale_image(im):
+    return cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
+
+
+def resize_image(im, w_h):
+    return cv2.resize(im, w_h, interpolation=cv2.INTER_AREA)
+
+
+def normalize_image(im):
+    '''Normalize an image by dividing by the max value 255'''
+    # NOTE: beware in its application, may cause loss to be 255 times lower due to smaller input values
+    return np.divide(im, 255.0)
+
+
+def preprocess_image(im):
+    '''
+    Image preprocessing using OpenAI Baselines method: grayscale, resize
+    This resize uses stretching instead of cropping
+    '''
+    im = to_opencv_image(im)
+    im = grayscale_image(im)
+    im = resize_image(im, (84, 84))
+    im = np.expand_dims(im, 0)
+    return im
+
+
+def debug_image(im):
+    '''
+    Renders an image for debugging; pauses process until key press
+    Handles tensor/numpy and conventions among libraries
+    '''
+    if torch.is_tensor(im):  # if PyTorch tensor, get numpy
+        im = im.cpu().numpy()
+    im = to_opencv_image(im)
+    im = im.astype(np.uint8)  # typecast guard
+    if im.shape[0] == 3:  # RGB image
+        # accommodate from RGB (numpy) to BGR (cv2)
+        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+    cv2.imshow('debug image', im)
+    cv2.waitKey(0)
+
+
+def mpl_debug_image(im):
+    '''Uses matplotlib to plot image with bigger size, axes, and false color on greyscaled images'''
+    import matplotlib.pyplot as plt
+    plt.figure()
+    plt.imshow(im)
+    plt.show()
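A shape sketch (editorial): an RGB frame in either c,h,w or h,w,c layout comes out as a single 84x84 grayscale plane with a leading channel axis.

import numpy as np
from convlab.lib import util

frame = np.random.randint(0, 255, (3, 210, 160), dtype=np.uint8)    # c,h,w frame
util.preprocess_image(frame).shape
# => (1, 84, 84): transposed to h,w,c, grayscaled, resized, channel axis restored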
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/lib/viz.html b/docs/build/html/_modules/convlab/lib/viz.html
new file mode 100644
index 0000000..b8b18da
--- /dev/null
+++ b/docs/build/html/_modules/convlab/lib/viz.html
@@ -0,0 +1,417 @@
+convlab.lib.viz — ConvLab 0.1 documentation

Source code for convlab.lib.viz

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+
+import colorlover as cl
+import pydash as ps
+# The data visualization module
+# Defines plotting methods for analysis
+from plotly import graph_objs as go, io as pio, tools
+from plotly.offline import init_notebook_mode, iplot
+
+from convlab.lib import logger, util
+
+logger = logger.get_logger(__name__)
+
+# warn orca failure only once
+orca_warn_once = ps.once(lambda e: logger.warning(f'Failed to generate graph. Run retro-analysis to generate graphs later.'))
+if util.is_jupyter():
+    init_notebook_mode(connected=True)
+
+
+
+def create_label(y_col, x_col, title=None, y_title=None, x_title=None, legend_name=None):
+    '''Create label dict for go.Layout with smart resolution'''
+    legend_name = legend_name or y_col
+    y_col_list, x_col_list, legend_name_list = ps.map_(
+        [y_col, x_col, legend_name], util.cast_list)
+    y_title = str(y_title or ','.join(y_col_list))
+    x_title = str(x_title or ','.join(x_col_list))
+    title = title or f'{y_title} vs {x_title}'
+
+    label = {
+        'y_title': y_title,
+        'x_title': x_title,
+        'title': title,
+        'y_col_list': y_col_list,
+        'x_col_list': x_col_list,
+        'legend_name_list': legend_name_list,
+    }
+    return label
+
+
+def create_layout(title, y_title, x_title, x_type=None, width=500, height=500, layout_kwargs=None):
+    '''simplified method to generate Layout'''
+    layout = go.Layout(
+        title=title,
+        legend=dict(x=0.0, y=-0.25, orientation='h'),
+        yaxis=dict(rangemode='tozero', title=y_title),
+        xaxis=dict(type=x_type, title=x_title),
+        width=width, height=height,
+        margin=go.layout.Margin(l=60, r=60, t=60, b=60),
+    )
+    layout.update(layout_kwargs)
+    return layout
+
+
+def get_palette(size):
+    '''Get the suitable palette of a certain size'''
+    if size <= 8:
+        palette = cl.scales[str(max(3, size))]['qual']['Set2']
+    else:
+        palette = cl.interp(cl.scales['8']['qual']['Set2'], size)
+    return palette
+
+
+def lower_opacity(rgb, opacity):
+    return rgb.replace('rgb(', 'rgba(').replace(')', f',{opacity})')
+
+
+def plot(*args, **kwargs):
+    if util.is_jupyter():
+        return iplot(*args, **kwargs)
+
+
+def plot_sr(sr, time_sr, title, y_title, x_title):
+    '''Plot a series'''
+    x = time_sr.tolist()
+    color = get_palette(1)[0]
+    main_trace = go.Scatter(
+        x=x, y=sr, mode='lines', showlegend=False,
+        line={'color': color, 'width': 1},
+    )
+    data = [main_trace]
+    layout = create_layout(title=title, y_title=y_title, x_title=x_title)
+    fig = go.Figure(data, layout)
+    plot(fig)
+    return fig
+
+
+def plot_mean_sr(sr_list, time_sr, title, y_title, x_title):
+    '''Plot a list of series using its mean, with error bar using std'''
+    mean_sr, std_sr = util.calc_srs_mean_std(sr_list)
+    max_sr = mean_sr + std_sr
+    min_sr = mean_sr - std_sr
+    max_y = max_sr.tolist()
+    min_y = min_sr.tolist()
+    x = time_sr.tolist()
+    color = get_palette(1)[0]
+    main_trace = go.Scatter(
+        x=x, y=mean_sr, mode='lines', showlegend=False,
+        line={'color': color, 'width': 1},
+    )
+    envelope_trace = go.Scatter(
+        x=x + x[::-1], y=max_y + min_y[::-1], showlegend=False,
+        line={'color': 'rgba(0, 0, 0, 0)'},
+        fill='tozerox', fillcolor=lower_opacity(color, 0.2),
+    )
+    data = [main_trace, envelope_trace]
+    layout = create_layout(title=title, y_title=y_title, x_title=x_title)
+    fig = go.Figure(data, layout)
+    return fig
+
+
+def save_image(figure, filepath):
+    if os.environ['PY_ENV'] == 'test':
+        return
+    filepath = util.smart_path(filepath)
+    try:
+        pio.write_image(figure, filepath)
+    except Exception as e:
+        orca_warn_once(e)
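A usage sketch (editorial): plotting a return curve from two pandas series; in Jupyter the figure also renders inline via iplot. Note that save_image assumes the PY_ENV env var is set (the lab sets it) and that plotly-orca is installed, warning once otherwise.

import pandas as pd
from convlab.lib import viz

returns = pd.Series([0.1, 0.4, 0.5, 0.7])
frames = pd.Series([0, 1000, 2000, 3000])
fig = viz.plot_sr(returns, frames, title='mean_return vs frames', y_title='mean_return', x_title='frames')
viz.save_image(fig, 'output/tmp/mean_return_vs_frames.png')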
+
+
+# analysis plot methods
+
+def plot_session(session_spec, session_metrics, session_df, df_mode='eval'):
+    '''
+    Plot the session graphs:
+    - mean_returns, strengths, sample_efficiencies, training_efficiencies, stabilities (with error bar)
+    - additional plots from session_df: losses, exploration variable, entropy
+    '''
+    meta_spec = session_spec['meta']
+    prepath = meta_spec['prepath']
+    graph_prepath = meta_spec['graph_prepath']
+    title = f'session graph: {session_spec["name"]} t{meta_spec["trial"]} s{meta_spec["session"]}'
+
+    local_metrics = session_metrics['local']
+    if df_mode == 'train':
+        name_time_pairs = [
+            ('mean_return', 'frames'),
+        ]
+        for name, time in name_time_pairs:
+            fig = plot_sr(
+                local_metrics[name], local_metrics[time], title, name, time)
+            save_image(fig, f'{graph_prepath}_session_graph_{df_mode}_{name}_vs_{time}.png')
+            save_image(fig, f'{prepath}_session_graph_{df_mode}_{name}_vs_{time}.png')
+        name_time_pairs = [
+            ('loss', 'frame'),
+            ('explore_var', 'frame'),
+            ('entropy', 'frame'),
+        ]
+        for name, time in name_time_pairs:
+            fig = plot_sr(
+                session_df[name], session_df[time], title, name, time)
+            save_image(fig, f'{graph_prepath}_session_graph_{df_mode}_{name}_vs_{time}.png')
+    else:
+        # training plots from session_df
+        name_time_pairs = [
+            ('mean_return', 'frames'),
+            ('mean_length', 'frames'),
+            ('mean_success', 'frames'),
+        ]
+        for name, time in name_time_pairs:
+            fig = plot_sr(
+                local_metrics[name], local_metrics[time], title, name, time)
+            save_image(fig, f'{graph_prepath}_session_graph_{df_mode}_{name}_vs_{time}.png')
+            save_image(fig, f'{prepath}_session_graph_{df_mode}_{name}_vs_{time}.png')
+
+
+def plot_trial(trial_spec, trial_metrics):
+    '''
+    Plot the trial graphs:
+    - mean_returns, strengths, sample_efficiencies, training_efficiencies, stabilities (with error bar)
+    - consistencies (no error bar)
+    '''
+    meta_spec = trial_spec['meta']
+    prepath = meta_spec['prepath']
+    graph_prepath = meta_spec['graph_prepath']
+    title = f'trial graph: {trial_spec["name"]} t{meta_spec["trial"]} {meta_spec["max_session"]} sessions'
+
+    local_metrics = trial_metrics['local']
+    name_time_pairs = [
+        ('mean_return', 'frames'),
+        ('mean_length', 'frames'),
+        ('mean_success', 'frames'),
+    ]
+    for name, time in name_time_pairs:
+        fig = plot_mean_sr(
+            local_metrics[name], local_metrics[time], title, name, time)
+        save_image(fig, f'{graph_prepath}_trial_graph_{name}_vs_{time}.png')
+        save_image(fig, f'{prepath}_trial_graph_{name}_vs_{time}.png')
+
+
+def plot_experiment(experiment_spec, experiment_df, metrics_cols):
+    '''
+    Plot the metrics vs. specs parameters of an experiment, where each point is a trial.
+    ref colors: https://plot.ly/python/heatmaps-contours-and-2dhistograms-tutorial/#plotlys-predefined-color-scales
+    '''
+    y_cols = metrics_cols
+    x_cols = ps.difference(experiment_df.columns.tolist(), y_cols)
+    fig = tools.make_subplots(rows=len(y_cols), cols=len(x_cols), shared_xaxes=True, shared_yaxes=True, print_grid=False)
+    strength_sr = experiment_df['strength']
+    min_strength = strength_sr.values.min()
+    max_strength = strength_sr.values.max()
+    for row_idx, y in enumerate(y_cols):
+        for col_idx, x in enumerate(x_cols):
+            x_sr = experiment_df[x]
+            guard_cat_x = x_sr.astype(str) if x_sr.dtype == 'object' else x_sr
+            trace = go.Scatter(
+                y=experiment_df[y], yaxis=f'y{row_idx+1}',
+                x=guard_cat_x, xaxis=f'x{col_idx+1}',
+                showlegend=False, mode='markers',
+                marker={
+                    'symbol': 'circle-open-dot', 'color': experiment_df['strength'], 'opacity': 0.5,
+                    # dump first quarter of colorscale that is too bright
+                    'cmin': min_strength - 0.50 * (max_strength - min_strength), 'cmax': max_strength,
+                    'colorscale': 'YlGnBu', 'reversescale': True
+                },
+            )
+            fig.add_trace(trace, row_idx + 1, col_idx + 1)
+            fig.layout[f'xaxis{col_idx+1}'].update(title='<br>'.join(ps.chunk(x, 20)), zerolinewidth=1, categoryarray=sorted(guard_cat_x.unique()))
+        fig.layout[f'yaxis{row_idx+1}'].update(title=y, rangemode='tozero')
+    fig.layout.update(
+        title=f'experiment graph: {experiment_spec["name"]}',
+        width=100 + 300 * len(x_cols), height=200 + 300 * len(y_cols))
+    plot(fig)
+    graph_prepath = experiment_spec['meta']['graph_prepath']
+    save_image(fig, f'{graph_prepath}_experiment_graph.png')
+    # save important graphs in prepath directly
+    prepath = experiment_spec['meta']['prepath']
+    save_image(fig, f'{prepath}_experiment_graph.png')
+    return fig
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/action_decoder/multiwoz/multiwoz_vocab_action_decoder.html b/docs/build/html/_modules/convlab/modules/action_decoder/multiwoz/multiwoz_vocab_action_decoder.html
new file mode 100644
index 0000000..b6aed29
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/action_decoder/multiwoz/multiwoz_vocab_action_decoder.html
@@ -0,0 +1,349 @@
+convlab.modules.action_decoder.multiwoz.multiwoz_vocab_action_decoder — ConvLab 0.1 documentation

Source code for convlab.modules.action_decoder.multiwoz.multiwoz_vocab_action_decoder

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import json
+import os
+
+from convlab.modules.policy.system.multiwoz.rule_based_multiwoz_bot import REF_SYS_DA, REF_USR_DA, generate_car, \
+    generate_ref_num
+from convlab.modules.util.multiwoz.dbquery import query
+
+DEFAULT_VOCAB_FILE=os.path.join(os.path.dirname(os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))),
+    "data/multiwoz/da_slot_cnt.json")
+
+
+
+class SkipException(Exception):
+    def __init__(self):
+        pass
+
+
+class ActionVocab(object):
+    def __init__(self, vocab_path=DEFAULT_VOCAB_FILE, num_actions=500):
+        # add general actions
+        self.vocab = [
+            {'general-welcome': ['none']},
+            {'general-greet': ['none']},
+            {'general-bye': ['none']},
+            {'general-reqmore': ['none']}
+        ]
+        # add single slot actions
+        for domain in REF_SYS_DA:
+            for slot in REF_SYS_DA[domain]:
+                self.vocab.append({domain + '-Inform': [slot]})
+                self.vocab.append({domain + '-Request': [slot]})
+        # add actions from stats
+        with open(vocab_path, 'r') as f:
+            stats = json.load(f)
+        for action_string in stats:
+            try:
+                act_strings = action_string.split(';];')
+                action_dict = {}
+                for act_string in act_strings:
+                    if act_string == '':
+                        continue
+                    domain_act, slots = act_string.split('[', 1)
+                    domain, act_type = domain_act.split('-')
+                    if act_type in ['NoOffer', 'OfferBook']:
+                        action_dict[domain_act] = ['none']
+                    elif act_type in ['Select']:
+                        if slots.startswith('none'):
+                            raise SkipException
+                        action_dict[domain_act] = [slots.split(';')[0]]
+                    else:
+                        action_dict[domain_act] = sorted(slots.split(';'))
+                if action_dict not in self.vocab:
+                    self.vocab.append(action_dict)
+            except SkipException:
+                print(act_strings)
+            if len(self.vocab) >= num_actions:
+                break
+        print("{} actions are added to vocab".format(len(self.vocab)))
+
+    def get_action(self, action_index):
+        return self.vocab[action_index]
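A usage sketch (editorial, assumes the bundled da_slot_cnt.json and MultiWOZ DB files are present): the decoder maps a policy's discrete action index back to a dialog act dict given the tracked dialog state; index 3 is the hard-coded general-reqmore act.

from convlab.modules.action_decoder.multiwoz.multiwoz_vocab_action_decoder import MultiWozVocabActionDecoder
from convlab.modules.dst.multiwoz.dst_util import init_state

decoder = MultiWozVocabActionDecoder()
decoder.decode(3, init_state())    # => {'general-reqmore': [['none', 'none']]}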
+
+
+class MultiWozVocabActionDecoder(object):
+    def __init__(self, vocab_path=None):
+        self.action_vocab = ActionVocab(num_actions=300)
+
+    def decode(self, action_index, state):
+        domains = ['Attraction', 'Hospital', 'Hotel', 'Restaurant', 'Taxi', 'Train', 'Police']
+        delex_action = self.action_vocab.get_action(action_index)
+        action = {}
+
+        for act in delex_action:
+            domain, act_type = act.split('-')
+            if act_type == 'Request':
+                action[act] = []
+                for slot in delex_action[act]:
+                    action[act].append([slot, '?'])
+            elif act == 'Booking-Book':
+                action['Booking-Book'] = [["Ref", generate_ref_num(8)]]
+            elif domain not in domains:
+                action[act] = [['none', 'none']]
+            else:
+                if act == 'Taxi-Inform':
+                    for info_slot in ['leaveAt', 'arriveBy']:
+                        if info_slot in state['belief_state']['taxi']['semi'] and \
+                                state['belief_state']['taxi']['semi'][info_slot] != "":
+                            car = generate_car()
+                            phone_num = generate_ref_num(11)
+                            action[act] = []
+                            action[act].append(['Car', car])
+                            action[act].append(['Phone', phone_num])
+                            break
+                    else:
+                        action[act] = [['none', 'none']]
+                elif act in ['Train-Inform', 'Train-NoOffer', 'Train-OfferBook']:
+                    for info_slot in ['departure', 'destination']:
+                        if info_slot not in state['belief_state']['train']['semi'] or \
+                                state['belief_state']['train']['semi'][info_slot] == "":
+                            action[act] = [['none', 'none']]
+                            break
+                    else:
+                        for info_slot in ['leaveAt', 'arriveBy']:
+                            if info_slot in state['belief_state']['train']['semi'] and \
+                                    state['belief_state']['train']['semi'][info_slot] != "":
+                                self.domain_fill(delex_action, state, action, act)
+                                break
+                        else:
+                            action[act] = [['none', 'none']]
+                elif domain in domains:
+                    self.domain_fill(delex_action, state, action, act)
+
+        return action
+
+    def domain_fill(self, delex_action, state, action, act):
+        domain, act_type = act.split('-')
+        constraints = []
+        for slot in state['belief_state'][domain.lower()]['semi']:
+            if state['belief_state'][domain.lower()]['semi'][slot] != "":
+                constraints.append([slot, state['belief_state'][domain.lower()]['semi'][slot]])
+        if act_type in ['NoOffer', 'OfferBook']:  # NoOffer['none'], OfferBook['none']
+            action[act] = []
+            for slot in constraints:
+                action[act].append([REF_USR_DA[domain].get(slot[0], slot[0]), slot[1]])
+        elif act_type in ['Inform', 'Recommend', 'OfferBooked']:  # Inform[Slot,...], Recommend[Slot, ...]
+            kb_result = query(domain.lower(), constraints)
+            if len(kb_result) == 0:
+                action[act] = [['none', 'none']]
+            else:
+                action[act] = []
+                for slot in delex_action[act]:
+                    if slot == 'Choice':
+                        action[act].append([slot, len(kb_result)])
+                    elif slot == 'Ref':
+                        action[act].append(["Ref", generate_ref_num(8)])
+                    else:
+                        try:
+                            action[act].append([slot, kb_result[0][REF_SYS_DA[domain].get(slot, slot)]])
+                        except Exception:
+                            action[act].append([slot, "N/A"])
+                if len(action[act]) == 0:
+                    action[act] = [['none', 'none']]
+        elif act_type in ['Select']:  # Select[Slot]
+            kb_result = query(domain.lower(), constraints)
+            if len(kb_result) < 2:
+                action[act] = [['none', 'none']]
+            else:
+                slot = delex_action[act][0]
+                action[act] = []
+                action[act].append([slot, kb_result[0][REF_SYS_DA[domain].get(slot, slot)]])
+                action[act].append([slot, kb_result[1][REF_SYS_DA[domain].get(slot, slot)]])
+        else:
+            print('Cannot decode:', str(delex_action))
+            action[act] = [['none', 'none']]
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/dst/multiwoz/dst_util.html b/docs/build/html/_modules/convlab/modules/dst/multiwoz/dst_util.html
new file mode 100644
index 0000000..819f32d
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/dst/multiwoz/dst_util.html
@@ -0,0 +1,447 @@
+convlab.modules.dst.multiwoz.dst_util — ConvLab 0.1 documentation

Source code for convlab.modules.dst.multiwoz.dst_util

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import re
+from difflib import SequenceMatcher
+
+init_belief_state = {
+        "police": {
+            "book": {
+                "booked": []
+            },
+            "semi": {}
+        },
+        "hotel": {
+            "book": {
+                "booked": [],
+                "people": "",
+                "day": "",
+                "stay": ""
+            },
+            "semi": {
+                "name": "",
+                "area": "",
+                "parking": "",
+                "pricerange": "",
+                "stars": "",
+                "internet": "",
+                "type": ""
+            }
+        },
+        "attraction": {
+            "book": {
+                "booked": []
+            },
+            "semi": {
+                "type": "",
+                "name": "",
+                "area": "",
+                "entrance fee": ""
+            }
+        },
+        "restaurant": {
+            "book": {
+                "booked": [],
+                "people": "",
+                "day": "",
+                "time": ""
+            },
+            "semi": {
+                "food": "",
+                "pricerange": "",
+                "name": "",
+                "area": "",
+            }
+        },
+        "hospital": {
+            "book": {
+                "booked": []
+            },
+            "semi": {
+                "department": ""
+            }
+        },
+        "taxi": {
+            "book": {
+                "booked": [],
+                "departure": "",
+                "destination": ""
+            },
+            "semi": {
+                "leaveAt": "",
+                "arriveBy": ""
+            }
+        },
+        "train": {
+            "book": {
+                "booked": [],
+                "people": "",
+                "trainID": ""
+            },
+            "semi": {
+                "leaveAt": "",
+                "destination": "",
+                "day": "",
+                "arriveBy": "",
+                "departure": ""
+            }
+        }
+    }
+
+
+
[docs]def init_state(): + """ + Build the initial state for a new session. + The returned dict has the form: + state = { + 'user_action': {}, + 'history': [], + 'belief_state': init_belief_state, + 'request_state': {} + } + """ + user_action = {} + state = {'user_action': user_action, + 'belief_state': init_belief_state, + 'request_state': {}, + 'history': []} + return state
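A quick look at the returned structure (the defaults come from init_belief_state above):

state = init_state()
state['belief_state']['hotel']['semi']['area']   # '' until the tracker fills it
state['request_state'], state['history']         # {}, []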
+ +
[docs]def str_similar(a, b): + return SequenceMatcher(None, a, b).ratio()
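For intuition: SequenceMatcher's ratio is 2*M/T, with M matched characters and T the combined length, so near-misses score high.

from difflib import SequenceMatcher

SequenceMatcher(None, 'centre', 'center').ratio()   # 2*5/12 ≈ 0.83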
+ +def _log(info): + # append-only log of fuzzy-match failures; the with-block closes the file + with open('fuzzy_recognition.log', 'a+') as f: + f.write('{}\n'.format(info)) + +
[docs]def minDistance(word1, word2): + """Return the minimum edit (Levenshtein) distance between word1 and word2.""" + if not word1: + return len(word2) if word2 else 0 + if not word2: + return len(word1) + size1 = len(word1) + size2 = len(word2) + tmp = list(range(size2 + 1)) + value = None + for i in range(size1): + tmp[0] = i + 1 + last = i + for j in range(size2): + if word1[i] == word2[j]: + value = last + else: + value = 1 + min(last, tmp[j], tmp[j + 1]) + last = tmp[j+1] + tmp[j+1] = value + return value
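A worked example with the classic Levenshtein pair (substitute k→s, substitute e→i, insert g):

minDistance('kitten', 'sitting')   # 3
minDistance('', 'abc')             # 3 (all insertions)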
+ +
[docs]def normalize_value(value_set, domain, slot, value): + """ + Normalize a value produced by the NLU module so that it maps into the ontology value space. + Args: + value_set (dict): The value set of the task ontology. + domain (str): The domain of the slot-value pair. + slot (str): The slot of the value. + value (str): The raw value detected by the NLU module. + Returns: + value (str): The normalized value, which fits the domain ontology. + """ + slot = slot.lower() + value = value.lower() + value = ' '.join(value.split()) + if domain not in value_set: + raise Exception('domain <{}> not found in value set'.format(domain)) + if slot not in value_set[domain]: + raise Exception('slot <{}> not found in db_values[{}]'.format(slot, domain)) + value_list = value_set[domain][slot] + # exact match or containing match + v = _match_or_contain(value, value_list) + if v is not None: + return v + # some transformations, e.g. "a 's" -> "a's", "center" <-> "centre" + cand_values = _transform_value(value) + for cv in cand_values: + v = _match_or_contain(cv, value_list) + if v is not None: + return v + # special value matching + v = special_match(domain, slot, value) + if v is not None: + return v + _log('Failed: domain {} slot {} value {}, raw value returned.'.format(domain, slot, value)) + return value
+ +def _transform_value(value): + cand_list = [] + # a 's -> a's + if " 's" in value: + cand_list.append(value.replace(" 's", "'s")) + # a - b -> a-b + if " - " in value: + cand_list.append(value.replace(" - ", "-")) + # center <-> centre + if value == 'center': + cand_list.append('centre') + elif value == 'centre': + cand_list.append('center') + # the + value + if not value.startswith('the '): + cand_list.append('the ' + value) + return cand_list + +def _match_or_contain(value, value_list): + """match value by exact match or containing""" + if value in value_list: + return value + for v in value_list: + if v in value or value in v: + return v + ## fuzzy match, when len(value) is large and distance(v1, v2) is small + for v in value_list: + d = minDistance(value, v) + if (d <= 2 and len(value) >= 10) or (d <= 3 and len(value) >= 15): + return v + return None + +
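An illustrative trace of the fuzzy branch; the length thresholds keep short strings from fuzzy-matching spuriously:

value = 'museum of archaelogy and anthropology'        # NLU typo
candidates = ['museum of archaeology and anthropology']
_match_or_contain(value, candidates)
# neither string contains the other, but minDistance == 1 and len(value) >= 10,
# so the ontology spelling is returned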
[docs]def special_match(domain, slot, value): + """special slot fuzzy matching""" + matched_result = None + if slot == 'arriveby' or slot == 'leaveat': + matched_result = _match_time(value) + elif slot == 'price' or slot == 'entrance fee': + matched_result = _match_pound_price(value) + elif slot == 'trainid': + matched_result = _match_trainid(value) + elif slot == 'duration': + matched_result = _match_duration(value) + return matched_result
+ +def _match_time(value): + """Return the time (arriveby, leaveat) in value, None if no time is found.""" + mat = re.search(r"(\d{1,2}:\d{1,2})", value) + if mat is not None and len(mat.groups()) > 0: + return mat.groups()[0] + return None + +def _match_trainid(value): + """Return the trainID in value, None if no trainID is found.""" + mat = re.search(r"TR(\d{4})", value) + if mat is not None and len(mat.groups()) > 0: + return mat.groups()[0] + return None + +def _match_pound_price(value): + """Return the price in pounds in value, None if no price is found.""" + mat = re.search(r"(\d{1,2},\d{1,2} pounds)", value) + if mat is not None and len(mat.groups()) > 0: + return mat.groups()[0] + mat = re.search(r"(\d{1,2} pounds)", value) + if mat is not None and len(mat.groups()) > 0: + return mat.groups()[0] + if "1 pound" in value.lower(): + return '1 pound' + if 'free' in value: + return 'free' + return None + +def _match_duration(value): + """Return the duration (in minutes) in value, None if no duration is found.""" + mat = re.search(r"(\d{1,2} minutes)", value) + if mat is not None and len(mat.groups()) > 0: + return mat.groups()[0] + return None + +if __name__ == "__main__": + # value_set = json.load(open('../../../data/multiwoz/db/db_values.json')) + # print(normalize_value(value_set, 'restaurant', 'address', 'regent street city center')) + print(minDistance("museum of archaeology and anthropology", "museum of archaelogy and anthropology"))
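A few examples of what these matchers extract (inputs chosen for illustration):

_match_time('i need to arrive by 14:30')        # '14:30'
_match_trainid('how about TR0192 ?')            # '0192' (group 1 only: the TR prefix is stripped)
_match_pound_price('entrance is 2,50 pounds')   # '2,50 pounds'
_match_duration('the trip takes 45 minutes')    # '45 minutes'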
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/dst/multiwoz/evaluate.html b/docs/build/html/_modules/convlab/modules/dst/multiwoz/evaluate.html
new file mode 100644
index 0000000..76ddb9a
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/dst/multiwoz/evaluate.html
@@ -0,0 +1,382 @@
+convlab.modules.dst.multiwoz.evaluate — ConvLab 0.1 documentation

Source code for convlab.modules.dst.multiwoz.evaluate

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import json
+
+from convlab.modules.dst.multiwoz.dst_util import minDistance
+from convlab.modules.dst.multiwoz.rule_dst import RuleDST
+from convlab.modules.nlu.multiwoz.onenet.nlu import OneNetLU
+
+
+
[docs]class NLU_DST: + def __init__(self): + self.dst = RuleDST() + self.nlu = OneNetLU(archive_file='models/onenet.tar.gz', cuda_device=-1, + model_file='https://convlab.blob.core.windows.net/models/onenet.tar.gz') + +
[docs] def update(self, action, observation): + # update history + self.dst.state['history'].append([str(action)]) + + # NLU parsing + input_act = self.nlu.parse(observation, sum(self.dst.state['history'], [])) if self.nlu else observation + + # state tracking + self.dst.update(input_act) + + # update history + self.dst.state['history'][-1].append(str(observation)) + return self.dst.state
+ +
[docs] def reset(self): + self.dst.init_session() + self.dst.state['history'].append(['null'])
+ +
[docs]def load_data(path='../../../../data/multiwoz/test.json'): + data = json.load(open(path)) + result = [] + for id, session in data.items(): + log = session['log'] + turn_data = ['null'] + goal = session['goal'] + session_data = [] + for turn_idx, turn in enumerate(log): + if turn_idx % 2 == 0: # user + observation = turn['text'] + turn_data.append(observation) + else: # system + action = turn['text'] + golden_state = turn['metadata'] + turn_data.append(golden_state) + session_data.append(turn_data) + turn_data = [action] + result.append([session_data, goal]) + return result
+ +
[docs]def run_test(): + agent = NLU_DST() + agent.reset() + + test_data = load_data() + test_result = [] + for session_data, goal in test_data: + session_result = [] + for action, observation, golden_state in session_data: + pred_state = agent.update(action, observation) + session_result.append([golden_state, pred_state['belief_state']]) + test_result.append([session_result, goal]) + agent.reset() + return test_result
+ +
[docs]class ResultStat: + def __init__(self): + self.stat = {} + +
[docs] def add(self, domain, slot, score): + """score = 1 or 0""" + if domain in self.stat: + if slot in self.stat[domain]: + self.stat[domain][slot][0] += score + self.stat[domain][slot][1] += 1. + else: + self.stat[domain][slot] = [score, 1.] + else: + self.stat[domain] = {slot: [score, 1.]}
+ +
[docs] def domain_acc(self, domain): + domain_stat = self.stat[domain] + ret = [0, 0] + for _, r in domain_stat.items(): + ret[0] += r[0] + ret[1] += r[1] + return ret[0]/(ret[1] + 1e-10)
+ +
[docs] def slot_acc(self, domain, slot): + slot_stat = self.stat[domain][slot] + return slot_stat[0] / (slot_stat[1] + 1e-10)
+ +
[docs] def all_acc(self): + acc_result = {} + for domain in self.stat: + acc_result[domain] = {} + acc_result[domain]['acc'] = self.domain_acc(domain) + for slot in self.stat[domain]: + acc_result[domain][slot+'_acc'] = self.slot_acc(domain, slot) + return json.dumps(acc_result, indent=4)
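A sketch of how the accumulator behaves (scores are 1/0 per slot comparison):

stat = ResultStat()
stat.add('hotel', 'area', 1.)
stat.add('hotel', 'area', 0.)
stat.slot_acc('hotel', 'area')   # ≈ 0.5 (1 correct out of 2)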
+ +
[docs]def evaluate(test_result): + stat = ResultStat() + session_level = [0., 0.] + for session, goal in test_result: + last_pred_state = None + for golden_state, pred_state in session: # session + last_pred_state = pred_state + domains = golden_state.keys() + for domain in domains: # domain + if domain == 'bus': + continue + assert domain in pred_state, 'domain: {}'.format(domain) + golden_domain, pred_domain = golden_state[domain], pred_state[domain] + for slot, value in golden_domain['semi'].items(): # slot + if _is_empty(slot, golden_domain['semi']): + continue + pv = pred_domain['semi'][slot] if slot in pred_domain['semi'] else '_None' + score = 0. + if _is_match(value, pv): + score = 1. + stat.add(domain, slot, score) + if match_goal(last_pred_state, goal): + session_level[0] += 1 + session_level[1] += 1 + print('domain and slot-level acc:') + print(stat.all_acc()) + print('session-level acc: {}'.format(convert2acc(session_level[0], session_level[1])))
+ +
[docs]def convert2acc(a, b): + if b == 0: + return -1 + return a/b
+ +
[docs]def match_goal(pred_state, goal): + domains = pred_state.keys() + for domain in domains: + if domain not in goal: + continue + goal_domain = goal[domain] + if 'info' not in goal_domain: + continue + goal_domain_info = goal_domain['info'] + for slot, value in goal_domain_info.items(): + if slot in pred_state[domain]['semi']: + v = pred_state[domain]['semi'][slot] + else: + return False + if _is_match(value, v): + continue + elif _fuzzy_match(value, v): + continue + else: + return False + return True
+ +def _is_empty(slot, domain_state): + if slot not in domain_state: + return True + value = domain_state[slot] + if value is None or value == "" or value == 'null': + return True + return False + +def _is_match(value1, value2): + if not isinstance(value1, str) or not isinstance(value2, str): + return value1 == value2 + value1 = value1.lower() + value2 = value2.lower() + value1 = ' '.join(value1.strip().split()) + value2 = ' '.join(value2.strip().split()) + if value1 == value2: + return True + return False + +def _fuzzy_match(value1, value2): + if not isinstance(value1, str) or not isinstance(value2, str): + return value1 == value2 + value1 = value1.lower() + value2 = value2.lower() + value1 = ' '.join(value1.strip().split()) + value2 = ' '.join(value2.strip().split()) + d = minDistance(value1, value2) + if (len(value1) >= 10 and d <= 2) or (len(value1) >= 15 and d <= 3): + return True + return False + +if __name__ == '__main__': + test_result = run_test() + json.dump(test_result, open('dst_test_result.json', 'w+'), indent=2) + print('test session num: {}'.format(len(test_result))) + evaluate(test_result) +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/dst/multiwoz/rule_dst.html b/docs/build/html/_modules/convlab/modules/dst/multiwoz/rule_dst.html
new file mode 100644
index 0000000..ba10964
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/dst/multiwoz/rule_dst.html
@@ -0,0 +1,265 @@
+convlab.modules.dst.multiwoz.rule_dst — ConvLab 0.1 documentation

Source code for convlab.modules.dst.multiwoz.rule_dst

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import copy
+import json
+import os
+
+import convlab
+from convlab.modules.dst.multiwoz.dst_util import init_state
+from convlab.modules.dst.multiwoz.dst_util import normalize_value
+from convlab.modules.dst.state_tracker import Tracker
+from convlab.modules.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
+
+
+
[docs]class RuleDST(Tracker): + """Rule-based DST that simply copies new values from the NLU result into the state.""" + def __init__(self): + Tracker.__init__(self) + self.state = init_state() + prefix = os.path.dirname(os.path.dirname(convlab.__file__)) + self.value_dict = json.load(open(prefix+'/data/multiwoz/value_dict.json')) +
[docs] def update(self, user_act=None): + # print('------------------{}'.format(user_act)) + if not isinstance(user_act, dict): + raise Exception('Expect user_act to be <class \'dict\'> type but get {}.'.format(type(user_act))) + previous_state = self.state + new_belief_state = copy.deepcopy(previous_state['belief_state']) + new_request_state = copy.deepcopy(previous_state['request_state']) + for domain_type in user_act.keys(): + domain, tpe = domain_type.lower().split('-') + if domain in ['unk', 'general', 'booking']: + continue + if tpe == 'inform': + for k, v in user_act[domain_type]: + k = REF_SYS_DA[domain.capitalize()].get(k, k) + if k is None: + continue + try: + assert domain in new_belief_state + except: + raise Exception('Error: domain <{}> not in new belief state'.format(domain)) + domain_dic = new_belief_state[domain] + assert 'semi' in domain_dic + assert 'book' in domain_dic + + if k in domain_dic['semi']: + nvalue = normalize_value(self.value_dict, domain, k, v) + # if nvalue != v: + # _log('domain {} slot {} value {} -> {}'.format(domain, k, v, nvalue)) + new_belief_state[domain]['semi'][k] = nvalue + elif k in domain_dic['book']: + new_belief_state[domain]['book'][k] = v + elif k.lower() in domain_dic['book']: + new_belief_state[domain]['book'][k.lower()] = v + elif k == 'trainID' and domain == 'train': + new_belief_state[domain]['book'][k] = normalize_value(self.value_dict, domain, k, v) + else: + # raise Exception('unknown slot name <{}> of domain <{}>'.format(k, domain)) + with open('unknown_slot.log', 'a+') as f: + f.write('unknown slot name <{}> of domain <{}>\n'.format(k, domain)) + elif tpe == 'request': + for k, v in user_act[domain_type]: + k = REF_SYS_DA[domain.capitalize()].get(k, k) + if domain not in new_request_state: + new_request_state[domain] = {} + if k not in new_request_state[domain]: + new_request_state[domain][k] = 0 + + new_state = copy.deepcopy(previous_state) + new_state['belief_state'] = new_belief_state + new_state['request_state'] = new_request_state + new_state['user_action'] = user_act + + self.state = new_state + + return self.state
+ +
[docs] def init_session(self): + self.state = init_state()
+
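A minimal usage sketch, assuming data/multiwoz/value_dict.json is in place and the usual MultiWOZ slot map applies (e.g. 'Area' -> 'area'):

dst = RuleDST()
dst.init_session()
state = dst.update({'Hotel-Inform': [['Area', 'north']]})
state['belief_state']['hotel']['semi']['area']   # 'north' after normalize_value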
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/dst/state_tracker.html b/docs/build/html/_modules/convlab/modules/dst/state_tracker.html
new file mode 100644
index 0000000..9eb225c
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/dst/state_tracker.html
@@ -0,0 +1,214 @@
+convlab.modules.dst.state_tracker — ConvLab 0.1 documentation

Source code for convlab.modules.dst.state_tracker

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+"""
+
+
[docs]class Tracker: + """Base class for dialog state tracker models.""" + def __init__(self): + """The constructor of Tracker class.""" + pass + +
[docs] def update(self, user_act=None): + """ + Update the dialog state based on a new user dialog act. + Args: + user_act (dict or str): The dialog act (or utterance) of the user input. Its type depends on + the tracker: rule-based trackers take a dict, while MDBT takes a str. + Returns: + new_state (dict): The updated dialog state, in the same form as the previous state. The state is + also kept as a private data member. + """ + pass
+ +
[docs] def init_session(self): + """Init the Tracker to start a new session.""" + pass
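A toy subclass, only to illustrate the interface contract (not part of the library):

class EchoTracker(Tracker):
    """Trivial tracker that records the last user act as its whole state."""
    def init_session(self):
        self.state = {'user_action': None}

    def update(self, user_act=None):
        self.state['user_action'] = user_act
        return self.state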
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Mem2Seq/utils/measures.html b/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Mem2Seq/utils/measures.html
new file mode 100644
index 0000000..8de5b2e
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Mem2Seq/utils/measures.html
@@ -0,0 +1,305 @@
+convlab.modules.e2e.multiwoz.Mem2Seq.utils.measures — ConvLab 0.1 documentation

Source code for convlab.modules.e2e.multiwoz.Mem2Seq.utils.measures

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import os
+import re
+import subprocess
+import tempfile
+
+import numpy
+import numpy as np
+from six.moves import urllib
+
+
+
[docs]def wer(r, h): + """ + Calculate the word error rate (WER) used in ASR, e.g.: + wer("what is it".split(), "what is".split()) + """ + # build the edit-distance matrix (int32: uint8 would overflow once a distance exceeds 255) + d = numpy.zeros((len(r)+1)*(len(h)+1), dtype=numpy.int32).reshape((len(r)+1, len(h)+1)) + for i in range(len(r)+1): + for j in range(len(h)+1): + if i == 0: d[0][j] = j + elif j == 0: d[i][0] = i + for i in range(1,len(r)+1): + for j in range(1, len(h)+1): + if r[i-1] == h[j-1]: + d[i][j] = d[i-1][j-1] + else: + substitute = d[i-1][j-1] + 1 + insert = d[i][j-1] + 1 + delete = d[i-1][j] + 1 + d[i][j] = min(substitute, insert, delete) + result = float(d[len(r)][len(h)]) / len(r) * 100 + return result
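Worked example: one deleted word against a three-word reference:

wer('what is it'.split(), 'what is'.split())   # 33.33... = 1 edit / 3 reference words * 100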
+ +# -*- coding: utf-8 -*- +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BLEU metric implementation. +""" + + +
[docs]def moses_multi_bleu(hypotheses, references, lowercase=False): + """Calculate the BLEU score for hypotheses and references + using the MOSES multi-bleu.perl script. + Args: + hypotheses: A numpy array of strings where each string is a single example. + references: A numpy array of strings where each string is a single example. + lowercase: If true, pass the "-lc" flag to the multi-bleu script + Returns: + The BLEU score as a float32 value. + """ + + if np.size(hypotheses) == 0: + return np.float32(0.0) + + + # Get MOSES multi-bleu script + try: + multi_bleu_path, _ = urllib.request.urlretrieve( + "https://raw.githubusercontent.com/moses-smt/mosesdecoder/" + "master/scripts/generic/multi-bleu.perl") + os.chmod(multi_bleu_path, 0o744) + except: #pylint: disable=W0702 + print("Unable to fetch multi-bleu.perl script, using local.") + metrics_dir = os.path.dirname(os.path.realpath(__file__)) + bin_dir = os.path.abspath(os.path.join(metrics_dir, "..", "..", "bin")) + multi_bleu_path = os.path.join(bin_dir, "tools/multi-bleu.perl") + + + # Dump hypotheses and references to tempfiles + hypothesis_file = tempfile.NamedTemporaryFile() + hypothesis_file.write("\n".join(hypotheses).encode("utf-8")) + hypothesis_file.write(b"\n") + hypothesis_file.flush() + reference_file = tempfile.NamedTemporaryFile() + reference_file.write("\n".join(references).encode("utf-8")) + reference_file.write(b"\n") + reference_file.flush() + + + # Calculate BLEU using multi-bleu script + with open(hypothesis_file.name, "r") as read_pred: + bleu_cmd = [multi_bleu_path] + if lowercase: + bleu_cmd += ["-lc"] + bleu_cmd += [reference_file.name] + try: + bleu_out = subprocess.check_output(bleu_cmd, stdin=read_pred, stderr=subprocess.STDOUT) + bleu_out = bleu_out.decode("utf-8") + bleu_score = re.search(r"BLEU = (.+?),", bleu_out).group(1) + bleu_score = float(bleu_score) + except subprocess.CalledProcessError as error: + if error.output is not None: + print("multi-bleu.perl script returned non-zero exit code") + print(error.output) + bleu_score = np.float32(0.0) + + # Close temp files + hypothesis_file.close() + reference_file.close() + return bleu_score
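Usage sketch; note the call shells out to multi-bleu.perl (fetched at call time, or read from a local tools/ directory on failure), so it needs network access or that script on disk:

import numpy as np

hyps = np.array(['the hotel is in the north part of town'])
refs = np.array(['the hotel is in the north part of town'])
moses_multi_bleu(hyps, refs, lowercase=True)   # BLEU on the 0-100 scale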
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/Sequicity.html b/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/Sequicity.html
new file mode 100644
index 0000000..3ed6535
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/Sequicity.html
@@ -0,0 +1,301 @@
+convlab.modules.e2e.multiwoz.Sequicity.Sequicity — ConvLab 0.1 documentation

Source code for convlab.modules.e2e.multiwoz.Sequicity.Sequicity

+# -*- coding: utf-8 -*-
+
+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+import random
+import zipfile
+
+import numpy as np
+import torch
+from nltk import word_tokenize
+from torch.autograd import Variable
+
+from convlab.lib.file_util import cached_path
+from convlab.modules.e2e.multiwoz.Sequicity.config import global_config as cfg
+from convlab.modules.e2e.multiwoz.Sequicity.model import Model
+from convlab.modules.e2e.multiwoz.Sequicity.reader import pad_sequences
+from convlab.modules.e2e.multiwoz.Sequicity.tsd_net import cuda_
+from convlab.modules.policy.system.policy import SysPolicy
+
+DEFAULT_CUDA_DEVICE=-1
+DEFAULT_DIRECTORY = "models"
+DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "Sequicity.rar")
+
+
[docs]def denormalize(uttr): + uttr = uttr.replace(' -s', 's') + uttr = uttr.replace(' -ly', 'ly') + uttr = uttr.replace(' -er', 'er') + return uttr
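The ' -s' style suffixes come from Sequicity's delexicalized vocabulary; for example:

denormalize('there are 3 hotel -s in the north')   # 'there are 3 hotels in the north'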
+ +
[docs]class Sequicity(SysPolicy): + def __init__(self, + archive_file=DEFAULT_ARCHIVE_FILE, + model_file=None): + SysPolicy.__init__(self) + + if not os.path.isfile(archive_file): + if not model_file: + raise Exception("No model for Sequicity is specified!") + archive_file = cached_path(model_file) + model_dir = os.path.dirname(os.path.abspath(__file__)) + if not os.path.exists(os.path.join(model_dir, 'data')): + archive = zipfile.ZipFile(archive_file, 'r') + archive.extractall(model_dir) + + cfg.init_handler('tsdf-multiwoz') + + torch.manual_seed(cfg.seed) + torch.cuda.manual_seed(cfg.seed) + random.seed(cfg.seed) + np.random.seed(cfg.seed) + self.m = Model('multiwoz') + self.m.count_params() + self.m.load_model() + self.reset() + +
[docs] def reset(self): + self.kw_ret = dict({'func':self.z2degree})
+ +
[docs] def z2degree(self, gen_z): + gen_bspan = self.m.reader.vocab.sentence_decode(gen_z, eos='EOS_Z2') + constraint_request = gen_bspan.split() + constraints = constraint_request[:constraint_request.index('EOS_Z1')] if 'EOS_Z1' \ + in constraint_request else constraint_request + for j, ent in enumerate(constraints): + constraints[j] = ent.replace('_', ' ') + degree = self.m.reader.db_search(constraints) + degree_input_list = self.m.reader._degree_vec_mapping(len(degree)) + degree_input = cuda_(Variable(torch.Tensor(degree_input_list).unsqueeze(0))) + return degree, degree_input
+ +
[docs] def predict(self, usr): + print('usr:', usr) + usr = word_tokenize(usr.lower()) + usr_words = usr + ['EOS_U'] + u_len = np.array([len(usr_words)]) + usr_indices = self.m.reader.vocab.sentence_encode(usr_words) + u_input_np = np.array(usr_indices)[:, np.newaxis] + u_input = cuda_(Variable(torch.from_numpy(u_input_np).long())) + m_idx, z_idx, degree = self.m.m(mode='test', degree_input=None, z_input=None, + u_input=u_input, u_input_np=u_input_np, u_len=u_len, + m_input=None, m_input_np=None, m_len=None, + turn_states=None, **self.kw_ret) + venue = random.sample(degree, 1)[0] if degree else dict() + l = [self.m.reader.vocab.decode(_) for _ in m_idx[0]] + if 'EOS_M' in l: + l = l[:l.index('EOS_M')] + l_origin = [] + for word in l: + if 'SLOT' in word: + word = word[:-5] + if word in venue.keys(): + value = venue[word] + if value != '?': + l_origin.append(value) + else: + l_origin.append(word) + sys = ' '.join(l_origin) + sys = denormalize(sys) + print('sys:', sys) + if cfg.prev_z_method == 'separate': + eob = self.m.reader.vocab.encode('EOS_Z2') + if eob in z_idx[0] and z_idx[0].index(eob) != len(z_idx[0]) - 1: + idx = z_idx[0].index(eob) + z_idx[0] = z_idx[0][:idx + 1] + for j, word in enumerate(z_idx[0]): + if word >= cfg.vocab_size: + z_idx[0][j] = 2 #unk + prev_z_input_np = pad_sequences(z_idx, cfg.max_ts, padding='post', truncating='pre').transpose((1, 0)) + prev_z_len = np.array([len(_) for _ in z_idx]) + prev_z_input = cuda_(Variable(torch.from_numpy(prev_z_input_np).long())) + self.kw_ret['prev_z_len'] = prev_z_len + self.kw_ret['prev_z_input'] = prev_z_input + self.kw_ret['prev_z_input_np'] = prev_z_input_np + return sys
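End-to-end usage sketch; the model_file URL is a placeholder here, substitute the released Sequicity checkpoint:

s = Sequicity(model_file='<url-to-sequicity-archive>')   # downloads and unpacks the model
print(s.predict('i am looking for a cheap restaurant in the centre'))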
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/metric.html b/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/metric.html
new file mode 100644
index 0000000..aa8dce5
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/metric.html
@@ -0,0 +1,818 @@
+convlab.modules.e2e.multiwoz.Sequicity.metric — ConvLab 0.1 documentation

Source code for convlab.modules.e2e.multiwoz.Sequicity.metric

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import argparse
+import csv
+import functools
+import json
+import math
+from collections import Counter
+
+from nltk.corpus import stopwords
+from nltk.stem import WordNetLemmatizer
+from nltk.tokenize import word_tokenize
+from nltk.util import ngrams
+
+from convlab.modules.e2e.multiwoz.Sequicity.reader import clean_replace
+
+en_sws = set(stopwords.words())
+wn = WordNetLemmatizer()
+
+order_to_number = {
+    'first': 1, 'one': 1, 'seco': 2, 'two': 2, 'third': 3, 'three': 3, 'four': 4, 'forth': 4, 'five': 5, 'fifth': 5,
+    'six': 6, 'seven': 7, 'eight': 8, 'nin': 9, 'ten': 10, 'eleven': 11, 'twelve': 12
+}
+
+
[docs]def similar(a,b): + return a == b or a in b or b in a or a.split()[0] == b.split()[0] or a.split()[-1] == b.split()[-1]
+ #return a == b or b.endswith(a) or a.endswith(b) + +
[docs]def setsub(a,b): + junks_a = [] + useless_constraint = ['temperature','week','est ','quick','reminder','near'] + for i in a: + flg = False + for j in b: + if similar(i,j): + flg = True + if not flg: + junks_a.append(i) + for junk in junks_a: + flg = False + for item in useless_constraint: + if item in junk: + flg = True + if not flg: + return False + return True
+ +
[docs]def setsim(a,b): + a,b = set(a),set(b) + return setsub(a,b) and setsub(b,a)
+ +
[docs]class BLEUScorer(object): + ## BLEU score calculator via the GentScorer interface. + ## Calculates corpus-level BLEU-4, scoring multiple candidates against multiple references. + def __init__(self): + pass +
[docs] def score(self, parallel_corpus): + + # containers + count = [0, 0, 0, 0] + clip_count = [0, 0, 0, 0] + r = 0 + c = 0 + weights = [0.25, 0.25, 0.25, 0.25] + + # accumulate ngram statistics + for hyps, refs in parallel_corpus: + hyps = [hyp.split() for hyp in hyps] + refs = [ref.split() for ref in refs] + for hyp in hyps: + + for i in range(4): + # accumulate ngram counts + hypcnts = Counter(ngrams(hyp, i + 1)) + cnt = sum(hypcnts.values()) + count[i] += cnt + + # compute clipped counts + max_counts = {} + for ref in refs: + refcnts = Counter(ngrams(ref, i + 1)) + for ng in hypcnts: + max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) + clipcnt = dict((ng, min(count, max_counts[ng])) \ + for ng, count in hypcnts.items()) + clip_count[i] += sum(clipcnt.values()) + + # accumulate r & c + bestmatch = [1000, 1000] + for ref in refs: + if bestmatch[0] == 0: break + diff = abs(len(ref) - len(hyp)) + if diff < bestmatch[0]: + bestmatch[0] = diff + bestmatch[1] = len(ref) + r += bestmatch[1] + c += len(hyp) + + # computing bleu score + p0 = 1e-7 + bp = 1 if c > r else math.exp(1 - float(r) / float(c)) + p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ + for i in range(4)] + s = math.fsum(w * math.log(p_n) \ + for w, p_n in zip(weights, p_ns) if p_n) + bleu = bp * math.exp(s) + return bleu
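Quick sanity check; the parallel corpus is a list of (hypotheses, references) pairs, each side a list of strings:

scorer = BLEUScorer()
corpus = [(['the cat sat on the mat'], ['the cat sat on the mat'])]
scorer.score(corpus)   # ≈ 1.0 for a perfect match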
+ + +
[docs]def report(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + res = func(*args, **kwargs) + args[0].metric_dict[func.__name__ + ' '+str(args[2])] = res + return res + return wrapper
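The decorator keys each result by the method name plus its third positional argument, which is why the metric methods are always called with the sub-name passed positionally:

# inside run_metrics: self.bleu_metric(data, 'bleu')
# -> self.metric_dict['bleu_metric bleu'] = <score>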
+ + +
[docs]class GenericEvaluator: + def __init__(self, result_path): + self.file = open(result_path,'r') + self.meta = [] + self.metric_dict = {} + self.entity_dict = {} + filename = result_path.split('/')[-1] + dump_dir = './sheets/' + filename.replace('.csv','.report.txt') + self.dump_file = open(dump_dir,'w') + + def _print_dict(self, dic): + for k, v in sorted(dic.items(),key=lambda x:x[0]): + print(k+'\t'+str(v)) + +
[docs] @report + def bleu_metric(self,data,type='bleu'): + gen, truth = [],[] + for row in data: + gen.append(row['generated_response']) + truth.append(row['response']) + wrap_generated = [[_] for _ in gen] + wrap_truth = [[_] for _ in truth] + sc = BLEUScorer().score(zip(wrap_generated, wrap_truth)) + return sc
+ +
[docs] def run_metrics(self): + raise ValueError('Please specify the evaluator first, bro')
+ +
[docs] def read_result_data(self): + while True: + line = self.file.readline() + if 'START_CSV_SECTION' in line: + break + self.meta.append(line) + reader = csv.DictReader(self.file) + data = [_ for _ in reader] + return data
+ + def _extract_constraint(self, z): + z = z.split() + if 'EOS_Z1' not in z: + return set(z).difference(['name', 'address', 'postcode', 'phone', 'area', 'pricerange', 'restaurant', + 'restaurants', 'style', 'price', 'food', 'EOS_M']) + else: + idx = z.index('EOS_Z1') + return set(z[:idx]).difference(['name', 'address', 'postcode', 'phone', 'area', 'pricerange', 'restaurant', + 'restaurants', 'style', 'price', 'food', 'EOS_M']) + + def _extract_request(self, z): + z = z.split() + if 'EOS_Z1' not in z or z[-1] == 'EOS_Z1': + return set() + else: + idx = z.index('EOS_Z1') + return set(z[idx+1:]) + +
[docs] def pack_dial(self,data): + dials = {} + for turn in data: + dial_id = int(turn['dial_id']) + if dial_id not in dials: + dials[dial_id] = [] + dials[dial_id].append(turn) + return dials
+ +
[docs] def dump(self): + self.dump_file.writelines(self.meta) + self.dump_file.write('START_REPORT_SECTION\n') + for k,v in self.metric_dict.items(): + self.dump_file.write('{}\t{}\n'.format(k,v))
+ + +
[docs] def clean(self,s): + s = s.replace('<go> ', '').replace(' SLOT', '_SLOT') + s = '<GO> ' + s + ' </s>' + for item in self.entity_dict: + # s = s.replace(item, 'VALUE_{}'.format(self.entity_dict[item])) + s = clean_replace(s, item, '{}_SLOT'.format(self.entity_dict[item])) + return s
+ + +
[docs]class CamRestEvaluator(GenericEvaluator): + def __init__(self, result_path): + super().__init__(result_path) + self.entities = [] + self.entity_dict = {} + +
[docs] def run_metrics(self): + raw_json = open('./data/CamRest676/CamRest676.json') + raw_entities = open('./data/CamRest676/CamRestOTGY.json') + raw_data = json.loads(raw_json.read().lower()) + raw_entities = json.loads(raw_entities.read().lower()) + self.get_entities(raw_entities) + data = self.read_result_data() + for i, row in enumerate(data): + data[i]['response'] = self.clean(data[i]['response']) + data[i]['generated_response'] = self.clean(data[i]['generated_response']) + bleu_score = self.bleu_metric(data,'bleu') + success_f1 = self.success_f1_metric(data, 'success') + match = self.match_metric(data, 'match', raw_data=raw_data) + self._print_dict(self.metric_dict) + return -success_f1[0]
+ +
[docs] def get_entities(self, entity_data): + for k in entity_data['informable']: + self.entities.extend(entity_data['informable'][k]) + for item in entity_data['informable'][k]: + self.entity_dict[item] = k
+ + def _extract_constraint(self, z): + z = z.split() + if 'EOS_Z1' not in z: + s = set(z) + else: + idx = z.index('EOS_Z1') + s = set(z[:idx]) + if 'moderately' in s: + s.discard('moderately') + s.add('moderate') + #print(self.entities) + #return s + return s.intersection(self.entities) + #return set(z).difference(['name', 'address', 'postcode', 'phone', 'area', 'pricerange']) + + def _extract_request(self, z): + z = z.split() + return set(z).intersection(['address', 'postcode', 'phone', 'area', 'pricerange','food']) + +
[docs] @report + def match_metric(self, data, sub='match',raw_data=None): + dials = self.pack_dial(data) + match,total = 0,1e-8 + success = 0 + # find out the last placeholder and see whether that is correct + # if no such placeholder, see the final turn, because it can be a yes/no question or scheduling dialogue + for dial_id in dials: + truth_req, gen_req = [], [] + dial = dials[dial_id] + gen_bspan, truth_cons, gen_cons = None, None, set() + truth_turn_num = -1 + truth_response_req = [] + for turn_num,turn in enumerate(dial): + if 'SLOT' in turn['generated_response']: + gen_bspan = turn['generated_bspan'] + gen_cons = self._extract_constraint(gen_bspan) + if 'SLOT' in turn['response']: + truth_cons = self._extract_constraint(turn['bspan']) + gen_response_token = turn['generated_response'].split() + response_token = turn['response'].split() + for idx, w in enumerate(gen_response_token): + if w.endswith('SLOT') and w != 'SLOT': + gen_req.append(w.split('_')[0]) + if w == 'SLOT' and idx != 0: + gen_req.append(gen_response_token[idx - 1]) + for idx, w in enumerate(response_token): + if w.endswith('SLOT') and w != 'SLOT': + truth_response_req.append(w.split('_')[0]) + if not gen_cons: + gen_bspan = dial[-1]['generated_bspan'] + gen_cons = self._extract_constraint(gen_bspan) + if truth_cons: + if gen_cons == truth_cons: + match += 1 + else: + print(gen_cons, truth_cons) + total += 1 + + return match / total, success / total
+ +
[docs] @report + def success_f1_metric(self, data, sub='successf1'): + dials = self.pack_dial(data) + tp,fp,fn = 0,0,0 + for dial_id in dials: + truth_req, gen_req = set(),set() + dial = dials[dial_id] + for turn_num, turn in enumerate(dial): + gen_response_token = turn['generated_response'].split() + response_token = turn['response'].split() + for idx, w in enumerate(gen_response_token): + if w.endswith('SLOT') and w != 'SLOT': + gen_req.add(w.split('_')[0]) + for idx, w in enumerate(response_token): + if w.endswith('SLOT') and w != 'SLOT': + truth_req.add(w.split('_')[0]) + + gen_req.discard('name') + truth_req.discard('name') + for req in gen_req: + if req in truth_req: + tp += 1 + else: + fp += 1 + for req in truth_req: + if req not in gen_req: + fn += 1 + precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + return f1, precision, recall
+ +
[docs]class KvretEvaluator(GenericEvaluator): + def __init__(self, result_path): + super().__init__(result_path) + ent_json = open('./data/kvret/kvret_entities.json') + self.ent_data = json.loads(ent_json.read().lower()) + ent_json.close() + self._get_entity_dict(self.ent_data) + raw_json = open('./data/kvret/kvret_test_public.json') + self.raw_data = json.loads(raw_json.read().lower()) + raw_json.close() + +
[docs] def run_metrics(self): + data = self.read_result_data() + for i, row in enumerate(data): + data[i]['response'] = self.clean_by_intent(data[i]['response'],int(data[i]['dial_id'])) + data[i]['generated_response'] = self.clean_by_intent(data[i]['generated_response'],int(data[i]['dial_id'])) + match_rate = self.match_rate_metric(data, 'match') + bleu_score = self.bleu_metric(data,'bleu') + success_f1 = self.success_f1_metric(data,'success_f1') + self._print_dict(self.metric_dict)
+ +
[docs] def clean_by_intent(self,s,i): + s = s.replace('<go> ', '').replace(' SLOT', '_SLOT') + s = '<GO> ' + s + ' </s>' + intent = self.raw_data[i]['scenario']['task']['intent'] + slot = { + 'weather':['weather_attribute','location','weekly_time'], + 'navigate':['poi','poi_type','distance','traffic','address'], + 'schedule':['event','date','time','party','room','agenda'] + } + + for item in self.entity_dict: + if self.entity_dict[item] in slot[intent]: + # s = s.replace(item, 'VALUE_{}'.format(self.entity_dict[item])) + s = clean_replace(s, item, '{}_SLOT'.format(self.entity_dict[item])) + return s
+ + + def _extract_constraint(self, z): + z = z.split() + if 'EOS_Z1' not in z: + s = set(z) + else: + idx = z.index('EOS_Z1') + s = set(z[:idx]) + reqs = ['address', 'traffic', 'poi', 'poi_type', 'distance', 'weather', 'temperature', 'weather_attribute', + 'date', 'time', 'location', 'event', 'agenda', 'party', 'room', 'weekly_time', 'forecast'] + informable = { + 'weather': ['date','location','weather_attribute'], + 'navigate': ['poi_type','distance'], + 'schedule': ['event', 'date', 'time', 'agenda', 'party', 'room'] + } + infs = [] + for v in informable.values(): + infs.extend(v) + junk = ['good','great','quickest','shortest','route','week','fastest','nearest','next','closest','way','mile', + 'activity','restaurant','appointment' ] + s = s.difference(junk).difference(en_sws).difference(reqs) + res = set() + for item in s: + if item in junk: + continue + flg = False + for canon_ent in sorted(list(self.entity_dict.keys())): + if self.entity_dict[canon_ent] in infs: + if similar(item, canon_ent): + flg = True + junk.extend(canon_ent.split()) + res.add(canon_ent) + if flg: + break + return res + +
[docs] def constraint_same(self, truth_cons, gen_cons): + if not truth_cons and not gen_cons: + return True + if not truth_cons or not gen_cons: + return False + return setsim(gen_cons, truth_cons)
+ + def _get_entity_dict(self, entity_data): + entity_dict = {} + for k in entity_data: + if isinstance(entity_data[k][0], str): + for entity in entity_data[k]: + entity = self._lemmatize(self._tokenize(entity)) + entity_dict[entity] = k + if k in ['event','poi_type']: + entity_dict[entity.split()[0]] = k + elif isinstance(entity_data[k][0], dict): + for entity_entry in entity_data[k]: + for entity_type, entity in entity_entry.items(): + entity_type = 'poi_type' if entity_type == 'type' else entity_type + entity = self._lemmatize(self._tokenize(entity)) + entity_dict[entity] = entity_type + if entity_type in ['event', 'poi_type']: + entity_dict[entity.split()[0]] = entity_type + self.entity_dict = entity_dict + +
[docs] @report + def match_rate_metric(self, data, sub='match',bspans='./data/kvret/test.bspan.pkl'): + dials = self.pack_dial(data) + match,total = 0,1e-8 + #bspan_data = pickle.load(open(bspans,'rb')) + # find out the last placeholder and see whether that is correct + # if no such placeholder, see the final turn, because it can be a yes/no question or scheduling conversation + for dial_id in dials: + dial = dials[dial_id] + gen_bspan, truth_cons, gen_cons = None, None, set() + truth_turn_num = -1 + for turn_num,turn in enumerate(dial): + if 'SLOT' in turn['generated_response']: + gen_bspan = turn['generated_bspan'] + gen_cons = self._extract_constraint(gen_bspan) + if 'SLOT' in turn['response']: + truth_cons = self._extract_constraint(turn['bspan']) + + # KVRET dataset includes "scheduling" (so often no SLOT decoded in ground truth) + if not truth_cons: + truth_bspan = dial[-1]['bspan'] + truth_cons = self._extract_constraint(truth_bspan) + if not gen_cons: + gen_bspan = dial[-1]['generated_bspan'] + gen_cons = self._extract_constraint(gen_bspan) + + if truth_cons: + if self.constraint_same(gen_cons, truth_cons): + match += 1 + #print(gen_cons, truth_cons, '+') + else: + print(gen_cons, truth_cons, '-') + total += 1 + + return match / total
+ + def _tokenize(self, sent): + return ' '.join(word_tokenize(sent)) + + def _lemmatize(self, sent): + words = [wn.lemmatize(_) for _ in sent.split()] + #for idx,w in enumerate(words): + # if w != + return ' '.join(words) + +
[docs] @report + def success_f1_metric(self, data, sub='successf1'): + dials = self.pack_dial(data) + tp,fp,fn = 0,0,0 + for dial_id in dials: + truth_req, gen_req = set(),set() + dial = dials[dial_id] + for turn_num, turn in enumerate(dial): + gen_response_token = turn['generated_response'].split() + response_token = turn['response'].split() + for idx, w in enumerate(gen_response_token): + if w.endswith('SLOT') and w != 'SLOT': + gen_req.add(w.split('_')[0]) + for idx, w in enumerate(response_token): + if w.endswith('SLOT') and w != 'SLOT': + truth_req.add(w.split('_')[0]) + gen_req.discard('name') + truth_req.discard('name') + for req in gen_req: + if req in truth_req: + tp += 1 + else: + fp += 1 + for req in truth_req: + if req not in gen_req: + fn += 1 + precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + return f1
+ +
[docs]class MultiWozEvaluator(GenericEvaluator): + def __init__(self, result_path): + super().__init__(result_path) + self.entities = [] + self.entity_dict = {} + +
[docs] def run_metrics(self): + with open('./data/MultiWoz/test.json') as f: + raw_data = json.loads(f.read().lower()) + with open('./data/MultiWoz/entities.json') as f: + raw_entities = json.loads(f.read().lower()) + self.get_entities(raw_entities) + data = self.read_result_data() + for i, row in enumerate(data): + data[i]['response'] = self.clean(data[i]['response']) + data[i]['generated_response'] = self.clean(data[i]['generated_response']) + bleu_score = self.bleu_metric(data,'bleu') + success_f1 = self.success_f1_metric(data, 'success') + match = self.match_metric(data, 'match', raw_data=raw_data) + self._print_dict(self.metric_dict) + return -success_f1[0]
+ +
[docs] def get_entities(self, entity_data): + for k in entity_data: + k_attr = k.split('_')[1][:-1] + self.entities.extend(entity_data[k]) + for item in entity_data[k]: + self.entity_dict[item] = k_attr
+ + def _extract_constraint(self, z): + z = z.split() + if 'EOS_Z1' not in z: + s = set(z) + else: + idx = z.index('EOS_Z1') + s = set(z[:idx]) + if 'moderately' in s: + s.discard('moderately') + s.add('moderate') + #print(self.entities) + #return s + return s.intersection(self.entities) + #return set(z).difference(['name', 'address', 'postcode', 'phone', 'area', 'pricerange']) + + def _extract_request(self, z): + z = z.split() + return set(z).intersection(['address', 'postcode', 'phone', 'area', 'pricerange','food']) + +
[docs] @report + def match_metric(self, data, sub='match',raw_data=None): + dials = self.pack_dial(data) + match,total = 0,1e-8 + # find out the last placeholder and see whether that is correct + # if no such placeholder, see the final turn, because it can be a yes/no question or scheduling dialogue + for dial_id in dials: + truth_req, gen_req = [], [] + dial = dials[dial_id] + gen_bspan, truth_cons, gen_cons = None, None, set() + truth_turn_num = -1 + truth_response_req = [] + for turn_num,turn in enumerate(dial): + if 'SLOT' in turn['generated_response']: + gen_bspan = turn['generated_bspan'] + gen_cons = self._extract_constraint(gen_bspan) + if 'SLOT' in turn['response']: + truth_cons = self._extract_constraint(turn['bspan']) + gen_response_token = turn['generated_response'].split() + response_token = turn['response'].split() + for idx, w in enumerate(gen_response_token): + if w.endswith('SLOT') and w != 'SLOT': + gen_req.append(w.split('_')[0]) + if w == 'SLOT' and idx != 0: + gen_req.append(gen_response_token[idx - 1]) + for idx, w in enumerate(response_token): + if w.endswith('SLOT') and w != 'SLOT': + truth_response_req.append(w.split('_')[0]) + if not gen_cons: + gen_bspan = dial[-1]['generated_bspan'] + gen_cons = self._extract_constraint(gen_bspan) + if truth_cons: + if gen_cons == truth_cons: + match += 1 + else: + pass +# print(gen_cons, truth_cons) + total += 1 + + return match / total
+ +
[docs] @report + def success_f1_metric(self, data, sub='successf1'): + dials = self.pack_dial(data) + tp,fp,fn = 0,0,0 + for dial_id in dials: + truth_req, gen_req = set(),set() + dial = dials[dial_id] + for turn_num, turn in enumerate(dial): + gen_response_token = turn['generated_response'].split() + response_token = turn['response'].split() + for idx, w in enumerate(gen_response_token): + if w.endswith('SLOT') and w != 'SLOT': + gen_req.add(w.split('_')[0]) + for idx, w in enumerate(response_token): + if w.endswith('SLOT') and w != 'SLOT': + truth_req.add(w.split('_')[0]) + + gen_req.discard('name') + truth_req.discard('name') + for req in gen_req: + if req in truth_req: + tp += 1 + else: + fp += 1 + for req in truth_req: + if req not in gen_req: + fn += 1 + precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + return f1, precision, recall
+ +
[docs]def metric_handler(): + parser = argparse.ArgumentParser() + parser.add_argument('-file') + parser.add_argument('-type') + args = parser.parse_args() + ev_class = None + if args.type == 'camrest': + ev_class = CamRestEvaluator + elif args.type == 'kvret': + ev_class = KvretEvaluator + elif args.type == 'multiwoz': + ev_class = MultiWozEvaluator + ev = ev_class(args.file) + ev.run_metrics() + ev.dump()
+ +if __name__ == '__main__': + metric_handler() +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/model.html b/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/model.html
new file mode 100644
index 0000000..3c83fa4
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/model.html
@@ -0,0 +1,706 @@
+convlab.modules.e2e.multiwoz.Sequicity.model — ConvLab 0.1 documentation

Source code for convlab.modules.e2e.multiwoz.Sequicity.model

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import argparse
+import logging
+import random
+import time
+
+import numpy as np
+import torch
+from nltk import word_tokenize
+from torch.autograd import Variable
+from torch.optim import Adam
+
+from convlab.modules.e2e.multiwoz.Sequicity.config import global_config as cfg
+from convlab.modules.e2e.multiwoz.Sequicity.metric import CamRestEvaluator, KvretEvaluator, MultiWozEvaluator
+from convlab.modules.e2e.multiwoz.Sequicity.reader import CamRest676Reader, KvretReader, MultiWozReader
+from convlab.modules.e2e.multiwoz.Sequicity.reader import get_glove_matrix
+from convlab.modules.e2e.multiwoz.Sequicity.reader import pad_sequences
+from convlab.modules.e2e.multiwoz.Sequicity.tsd_net import TSD, cuda_
+
+
+
[docs]class Model: + def __init__(self, dataset): + reader_dict = { + 'camrest': CamRest676Reader, + 'kvret': KvretReader, + 'multiwoz': MultiWozReader + } + model_dict = { + 'TSD':TSD + } + evaluator_dict = { + 'camrest': CamRestEvaluator, + 'kvret': KvretEvaluator, + 'multiwoz': MultiWozEvaluator + } + self.reader = reader_dict[dataset]() + self.m = model_dict[cfg.m](embed_size=cfg.embedding_size, + hidden_size=cfg.hidden_size, + vocab_size=cfg.vocab_size, + layer_num=cfg.layer_num, + dropout_rate=cfg.dropout_rate, + z_length=cfg.z_length, + max_ts=cfg.max_ts, + beam_search=cfg.beam_search, + beam_size=cfg.beam_size, + eos_token_idx=self.reader.vocab.encode('EOS_M'), + vocab=self.reader.vocab, + teacher_force=cfg.teacher_force, + degree_size=cfg.degree_size) + self.EV = evaluator_dict[dataset] # evaluator class + if cfg.cuda: self.m = self.m.cuda() + self.base_epoch = -1 + + def _convert_batch(self, py_batch, prev_z_py=None): + u_input_py = py_batch['user'] + u_len_py = py_batch['u_len'] + kw_ret = {} + if cfg.prev_z_method == 'concat' and prev_z_py is not None: + for i in range(len(u_input_py)): + eob = self.reader.vocab.encode('EOS_Z2') + if eob in prev_z_py[i] and prev_z_py[i].index(eob) != len(prev_z_py[i]) - 1: + idx = prev_z_py[i].index(eob) + u_input_py[i] = prev_z_py[i][:idx + 1] + u_input_py[i] + else: + u_input_py[i] = prev_z_py[i] + u_input_py[i] + u_len_py[i] = len(u_input_py[i]) + for j, word in enumerate(prev_z_py[i]): + if word >= cfg.vocab_size: + prev_z_py[i][j] = 2 #unk + elif cfg.prev_z_method == 'separate' and prev_z_py is not None: + for i in range(len(prev_z_py)): + eob = self.reader.vocab.encode('EOS_Z2') + if eob in prev_z_py[i] and prev_z_py[i].index(eob) != len(prev_z_py[i]) - 1: + idx = prev_z_py[i].index(eob) + prev_z_py[i] = prev_z_py[i][:idx + 1] + for j, word in enumerate(prev_z_py[i]): + if word >= cfg.vocab_size: + prev_z_py[i][j] = 2 #unk + prev_z_input_np = pad_sequences(prev_z_py, cfg.max_ts, padding='post', truncating='pre').transpose((1, 0)) + prev_z_len = np.array([len(_) for _ in prev_z_py]) + prev_z_input = cuda_(Variable(torch.from_numpy(prev_z_input_np).long())) + kw_ret['prev_z_len'] = prev_z_len + kw_ret['prev_z_input'] = prev_z_input + kw_ret['prev_z_input_np'] = prev_z_input_np + + degree_input_np = np.array(py_batch['degree']) + u_input_np = pad_sequences(u_input_py, cfg.max_ts, padding='post', truncating='pre').transpose((1, 0)) + z_input_np = pad_sequences(py_batch['bspan'], padding='post').transpose((1, 0)) + m_input_np = pad_sequences(py_batch['response'], cfg.max_ts, padding='post', truncating='post').transpose( + (1, 0)) + + u_len = np.array(u_len_py) + m_len = np.array(py_batch['m_len']) + + degree_input = cuda_(Variable(torch.from_numpy(degree_input_np).float())) + u_input = cuda_(Variable(torch.from_numpy(u_input_np).long())) + z_input = cuda_(Variable(torch.from_numpy(z_input_np).long())) + m_input = cuda_(Variable(torch.from_numpy(m_input_np).long())) + + kw_ret['z_input_np'] = z_input_np + + return u_input, u_input_np, z_input, m_input, m_input_np,u_len, m_len, \ + degree_input, kw_ret + +
[docs] def train(self): + lr = cfg.lr + prev_min_loss, early_stop_count = 1 << 30, cfg.early_stop_count + train_time = 0 + for epoch in range(cfg.epoch_num): + sw = time.time() + if epoch <= self.base_epoch: + continue + self.training_adjust(epoch) + self.m.self_adjust(epoch) + sup_loss = 0 + sup_cnt = 0 + data_iterator = self.reader.mini_batch_iterator('train') + optim = Adam(lr=lr, params=filter(lambda x: x.requires_grad, self.m.parameters()),weight_decay=1e-5) + for iter_num, dial_batch in enumerate(data_iterator): + turn_states = {} + prev_z = None + for turn_num, turn_batch in enumerate(dial_batch): + if cfg.truncated: + logging.debug('iter %d turn %d' % (iter_num, turn_num)) + optim.zero_grad() + u_input, u_input_np, z_input, m_input, m_input_np, u_len, \ + m_len, degree_input, kw_ret \ + = self._convert_batch(turn_batch, prev_z) + + loss, pr_loss, m_loss, turn_states = self.m(u_input=u_input, z_input=z_input, + m_input=m_input, + degree_input=degree_input, + u_input_np=u_input_np, + m_input_np=m_input_np, + turn_states=turn_states, + u_len=u_len, m_len=m_len, mode='train', **kw_ret) + loss.backward(retain_graph=turn_num != len(dial_batch) - 1) + grad = torch.nn.utils.clip_grad_norm(self.m.parameters(), 5.0) + optim.step() + sup_loss += loss.item() + sup_cnt += 1 + logging.debug( + 'loss:{} pr_loss:{} m_loss:{} grad:{}'.format(loss.item(), + pr_loss.item(), + m_loss.item(), + grad)) + + prev_z = turn_batch['bspan'] + + epoch_sup_loss = sup_loss / (sup_cnt + 1e-8) + train_time += time.time() - sw + logging.info('Training time: {}'.format(train_time)) + logging.info('avg training loss in epoch %d sup:%f' % (epoch, epoch_sup_loss)) + + valid_sup_loss, valid_unsup_loss = self.validate() + logging.info('validation loss in epoch %d sup:%f unsup:%f' % (epoch, valid_sup_loss, valid_unsup_loss)) + logging.info('time for epoch %d: %f' % (epoch, time.time()-sw)) + valid_loss = valid_sup_loss + valid_unsup_loss + + if valid_loss <= prev_min_loss: + self.save_model(epoch) + prev_min_loss = valid_loss + else: + early_stop_count -= 1 + lr *= cfg.lr_decay + if not early_stop_count: + break + logging.info('early stop countdown %d, learning rate %f' % (early_stop_count, lr))
+ +
[docs] def eval(self, data='test'): + self.m.eval() + self.reader.result_file = None + data_iterator = self.reader.mini_batch_iterator(data) + mode = 'test' if not cfg.pretrain else 'pretrain_test' + for batch_num, dial_batch in enumerate(data_iterator): + turn_states = {} + prev_z = None + for turn_num, turn_batch in enumerate(dial_batch): + u_input, u_input_np, z_input, m_input, m_input_np, u_len, \ + m_len, degree_input, kw_ret \ + = self._convert_batch(turn_batch, prev_z) + m_idx, z_idx, turn_states = self.m(mode=mode, u_input=u_input, u_len=u_len, z_input=z_input, + m_input=m_input, + degree_input=degree_input, u_input_np=u_input_np, + m_input_np=m_input_np, + m_len=m_len, turn_states=turn_states,**kw_ret) + self.reader.wrap_result(turn_batch, m_idx, z_idx, prev_z=prev_z) + prev_z = z_idx + ev = self.EV(result_path=cfg.result_path) + res = ev.run_metrics() + self.m.train() + return res
+ +
[docs] def interact(self): + def z2degree(gen_z): + gen_bspan = self.reader.vocab.sentence_decode(gen_z, eos='EOS_Z2') + constraint_request = gen_bspan.split() + constraints = constraint_request[:constraint_request.index('EOS_Z1')] if 'EOS_Z1' \ + in constraint_request else constraint_request + for j, ent in enumerate(constraints): + constraints[j] = ent.replace('_', ' ') + degree = self.reader.db_search(constraints) + degree_input_list = self.reader._degree_vec_mapping(len(degree)) + degree_input = cuda_(Variable(torch.Tensor(degree_input_list).unsqueeze(0))) + return degree, degree_input + + def denormalize(uttr): + uttr = uttr.replace(' -s', 's') + uttr = uttr.replace(' -ly', 'ly') + uttr = uttr.replace(' -er', 'er') + return uttr + + self.m.eval() + print('Start interaction.') + kw_ret = dict({'func':z2degree}) + while True: + usr = input('usr: ') + if usr == 'END': + break + if usr == 'RESET': + kw_ret = dict({'func':z2degree}) + continue + usr = word_tokenize(usr.lower()) + usr_words = usr + ['EOS_U'] + u_len = np.array([len(usr_words)]) + usr_indices = self.reader.vocab.sentence_encode(usr_words) + u_input_np = np.array(usr_indices)[:, np.newaxis] + u_input = cuda_(Variable(torch.from_numpy(u_input_np).long())) + m_idx, z_idx, degree = self.m(mode='test', degree_input=None, z_input=None, + u_input=u_input, u_input_np=u_input_np, u_len=u_len, + m_input=None, m_input_np=None, m_len=None, + turn_states=None, **kw_ret) + venue = random.sample(degree, 1)[0] if degree else dict() + l = [self.reader.vocab.decode(_) for _ in m_idx[0]] + if 'EOS_M' in l: + l = l[:l.index('EOS_M')] + l_origin = [] + for word in l: + if 'SLOT' in word: + word = word[:-5] + if word in venue.keys(): + value = venue[word] + if value != '?': + l_origin.append(value) + else: + l_origin.append(word) + sys = ' '.join(l_origin) + sys = denormalize(sys) + print('sys:', sys) + if cfg.prev_z_method == 'separate': + eob = self.reader.vocab.encode('EOS_Z2') + if eob in z_idx[0] and z_idx[0].index(eob) != len(z_idx[0]) - 1: + idx = z_idx[0].index(eob) + z_idx[0] = z_idx[0][:idx + 1] + for j, word in enumerate(z_idx[0]): + if word >= cfg.vocab_size: + z_idx[0][j] = 2 #unk + prev_z_input_np = pad_sequences(z_idx, cfg.max_ts, padding='post', truncating='pre').transpose((1, 0)) + prev_z_len = np.array([len(_) for _ in z_idx]) + prev_z_input = cuda_(Variable(torch.from_numpy(prev_z_input_np).long())) + kw_ret['prev_z_len'] = prev_z_len + kw_ret['prev_z_input'] = prev_z_input + kw_ret['prev_z_input_np'] = prev_z_input_np
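interact() is a small REPL; a hedged sketch of driving it (assumes cfg.init_handler was called beforehand, as Sequicity.__init__ does):

m = Model('multiwoz')
m.load_model()
m.interact()   # type utterances at 'usr:'; 'RESET' clears context, 'END' quits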
+ +
[docs] def predict(self, usr, kw_ret): + def z2degree(gen_z): + gen_bspan = self.reader.vocab.sentence_decode(gen_z, eos='EOS_Z2') + constraint_request = gen_bspan.split() + constraints = constraint_request[:constraint_request.index('EOS_Z1')] if 'EOS_Z1' \ + in constraint_request else constraint_request + for j, ent in enumerate(constraints): + constraints[j] = ent.replace('_', ' ') + degree = self.reader.db_search(constraints) + degree_input_list = self.reader._degree_vec_mapping(len(degree)) + degree_input = cuda_(Variable(torch.Tensor(degree_input_list).unsqueeze(0))) + return degree, degree_input + + self.m.eval() + + kw_ret['func'] = z2degree + if 'prev_z_input_np' in kw_ret: + kw_ret['prev_z_len'] = np.array(kw_ret['prev_z_len']) + kw_ret['prev_z_input_np'] = np.array(kw_ret['prev_z_input_np']) + kw_ret['prev_z_input'] = cuda_(Variable(torch.Tensor(kw_ret['prev_z_input_np']).long())) + + usr = word_tokenize(usr.lower()) + + usr_words = usr + ['EOS_U'] + u_len = np.array([len(usr_words)]) + usr_indices = self.reader.vocab.sentence_encode(usr_words) + u_input_np = np.array(usr_indices)[:, np.newaxis] + u_input = cuda_(Variable(torch.from_numpy(u_input_np).long())) + m_idx, z_idx, degree = self.m(mode='test', degree_input=None, z_input=None, + u_input=u_input, u_input_np=u_input_np, u_len=u_len, + m_input=None, m_input_np=None, m_len=None, + turn_states=None, **kw_ret) + venue = random.sample(degree, 1)[0] if degree else dict() + l = [self.reader.vocab.decode(_) for _ in m_idx[0]] + if 'EOS_M' in l: + l = l[:l.index('EOS_M')] + l_origin = [] + for word in l: + if 'SLOT' in word: + word = word[:-5] + if word in venue.keys(): + value = venue[word] + if value != '?': + l_origin.append(value.replace(' ', '_')) + else: + l_origin.append(word) + sys = ' '.join(l_origin) + kw_ret['sys'] = sys + if cfg.prev_z_method == 'separate': + eob = self.reader.vocab.encode('EOS_Z2') + if eob in z_idx[0] and z_idx[0].index(eob) != len(z_idx[0]) - 1: + idx = z_idx[0].index(eob) + z_idx[0] = z_idx[0][:idx + 1] + for j, word in enumerate(z_idx[0]): + if word >= cfg.vocab_size: + z_idx[0][j] = 2 #unk + prev_z_input_np = pad_sequences(z_idx, cfg.max_ts, padding='post', truncating='pre').transpose((1, 0)) + prev_z_len = np.array([len(_) for _ in z_idx]) + kw_ret['prev_z_len'] = prev_z_len.tolist() + kw_ret['prev_z_input_np'] = prev_z_input_np.tolist() + if 'prev_z_input' in kw_ret: + del kw_ret['prev_z_input'] + + del kw_ret['func'] + + return kw_ret
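Unlike interact(), predict() keeps its conversational state JSON-serializable: numpy arrays leave as plain lists via .tolist(), come back via np.array(), and the torch tensor prev_z_input is dropped and rebuilt from prev_z_input_np on each call. A hypothetical multi-turn driver under that contract (the Model instance m and the utterances are assumptions):

kw_ret = {}   # state carried between turns; stays a plain, serializable dict
for utterance in ['i need a cheap hotel in the north', 'what is the phone number ?']:
    kw_ret = m.predict(utterance, kw_ret)
    print(kw_ret['sys'])   # system response with SLOT placeholders filled from the DB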
+ +
[docs] def validate(self, data='dev'): + self.m.eval() + data_iterator = self.reader.mini_batch_iterator(data) + sup_loss, unsup_loss = 0, 0 + sup_cnt, unsup_cnt = 0, 0 + for dial_batch in data_iterator: + turn_states = {} + for turn_num, turn_batch in enumerate(dial_batch): + u_input, u_input_np, z_input, m_input, m_input_np, u_len, \ + m_len, degree_input, kw_ret \ + = self._convert_batch(turn_batch) + + loss, pr_loss, m_loss, turn_states = self.m(u_input=u_input, z_input=z_input, + m_input=m_input, + turn_states=turn_states, + degree_input=degree_input, + u_input_np=u_input_np, m_input_np=m_input_np, + u_len=u_len, m_len=m_len, mode='train',**kw_ret) + sup_loss += loss.item() + sup_cnt += 1 + logging.debug( + 'loss:{} pr_loss:{} m_loss:{}'.format(loss.item(), pr_loss.item(), m_loss.item())) + + sup_loss /= (sup_cnt + 1e-8) + unsup_loss /= (unsup_cnt + 1e-8) + self.m.train() + print('result preview...') + self.eval() + return sup_loss, unsup_loss
+ +
[docs] def reinforce_tune(self): + lr = cfg.lr + prev_min_loss, early_stop_count = 1 << 30, cfg.early_stop_count + for epoch in range(self.base_epoch + cfg.rl_epoch_num + 1): + mode = 'rl' + if epoch <= self.base_epoch: + continue + epoch_loss, cnt = 0,0 + data_iterator = self.reader.mini_batch_iterator('train') + optim = Adam(lr=lr, params=filter(lambda x: x.requires_grad, self.m.parameters()), weight_decay=1e-5) + for iter_num, dial_batch in enumerate(data_iterator): + turn_states = {} + prev_z = None + for turn_num, turn_batch in enumerate(dial_batch): + optim.zero_grad() + u_input, u_input_np, z_input, m_input, m_input_np, u_len, \ + m_len, degree_input, kw_ret \ + = self._convert_batch(turn_batch, prev_z) + loss_rl = self.m(u_input=u_input, z_input=z_input, + m_input=m_input, + degree_input=degree_input, + u_input_np=u_input_np, + m_input_np=m_input_np, + turn_states=turn_states, + u_len=u_len, m_len=m_len, mode=mode, **kw_ret) + + if loss_rl is not None: + loss = loss_rl + loss.backward() + grad = torch.nn.utils.clip_grad_norm(self.m.parameters(), 2.0) + optim.step() + epoch_loss += loss.item() + cnt += 1 + logging.debug('{} loss {}, grad:{}'.format(mode,loss.item(),grad)) + + prev_z = turn_batch['bspan'] + + epoch_sup_loss = epoch_loss / (cnt + 1e-8) + logging.info('avg training loss in epoch %d sup:%f' % (epoch, epoch_sup_loss)) + + valid_sup_loss, valid_unsup_loss = self.validate() + logging.info('validation loss in epoch %d sup:%f unsup:%f' % (epoch, valid_sup_loss, valid_unsup_loss)) + valid_loss = valid_sup_loss + valid_unsup_loss + + self.save_model(epoch) + + if valid_loss <= prev_min_loss: + #self.save_model(epoch) + prev_min_loss = valid_loss + else: + early_stop_count -= 1 + lr *= cfg.lr_decay + if not early_stop_count: + break + logging.info('early stop countdown %d, learning rate %f' % (early_stop_count, lr))
+ +
[docs] def save_model(self, epoch, path=None):
        if not path:
            path = cfg.model_path
        all_state = {'lstd': self.m.state_dict(),
                     'config': cfg.__dict__,
                     'epoch': epoch}
        torch.save(all_state, path)
+ +
[docs] def load_model(self, path=None):
        if not path:
            path = cfg.model_path
        all_state = torch.load(path)
        self.m.load_state_dict(all_state['lstd'])
        self.base_epoch = all_state.get('epoch', 0)
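save_model()/load_model() use a single-dict checkpoint with keys 'lstd' (the module state dict), 'config' (a snapshot of cfg.__dict__) and 'epoch'. A quick inspection sketch (the file name is an assumption; cfg.model_path is the default):

import torch

ckpt = torch.load('sequicity.pkl', map_location='cpu')  # hypothetical path
print(ckpt['epoch'])                  # epoch at which this model was saved
print(sorted(ckpt['config'].keys()))  # the recorded configuration
state_dict = ckpt['lstd']             # feed to nn.Module.load_state_dict()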
+ +
[docs] def training_adjust(self, epoch):
        return
+ +
[docs] def freeze_module(self, module):
        for param in module.parameters():
            param.requires_grad = False
+ +
[docs] def unfreeze_module(self, module):
        for param in module.parameters():
            param.requires_grad = True
+ +
[docs] def load_glove_embedding(self, freeze=False):
        initial_arr = self.m.u_encoder.embedding.weight.data.cpu().numpy()
        embedding_arr = torch.from_numpy(get_glove_matrix(self.reader.vocab, initial_arr))

        self.m.u_encoder.embedding.weight.data.copy_(embedding_arr)
        self.m.z_decoder.emb.weight.data.copy_(embedding_arr)
        self.m.m_decoder.emb.weight.data.copy_(embedding_arr)
+ +
[docs] def count_params(self):
        module_parameters = filter(lambda p: p.requires_grad, self.m.parameters())
        param_cnt = sum([np.prod(p.size()) for p in module_parameters])
        print('total trainable params: %d' % param_cnt)
+ + +
[docs]def main(arg_mode=None, arg_model=None): + + parser = argparse.ArgumentParser() + parser.add_argument('-mode') + parser.add_argument('-model') + parser.add_argument('-cfg', nargs='*') + args = parser.parse_args() + + if arg_mode is not None: + args.mode = arg_mode + if arg_model is not None: + args.model = arg_model + + cfg.init_handler(args.model) + + if args.cfg: + for pair in args.cfg: + k, v = tuple(pair.split('=')) + dtype = type(getattr(cfg, k)) + if isinstance(None, dtype): + raise ValueError() + if dtype is bool: + v = False if v == 'False' else True + else: + v = dtype(v) + setattr(cfg, k, v) + + logging.debug(str(cfg)) + if cfg.cuda: + logging.debug('Device: {}'.format(torch.cuda.current_device())) + cfg.mode = args.mode + + torch.manual_seed(cfg.seed) + torch.cuda.manual_seed(cfg.seed) + random.seed(cfg.seed) + np.random.seed(cfg.seed) + + m = Model(args.model.split('-')[-1]) + m.count_params() + if args.mode == 'train': + m.load_glove_embedding() + m.train() + elif args.mode == 'adjust': + m.load_model() + m.train() + elif args.mode == 'test': + m.load_model() + m.eval() + elif args.mode == 'rl': + m.load_model() + m.reinforce_tune() + elif args.mode == 'interact': + m.load_model() + m.interact() + elif args.mode == 'load': + m.load_model() + return m
+ +if __name__ == '__main__': + main() +
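main() doubles as a small CLI: -mode selects the routine (train / adjust / test / rl / interact / load), -model selects the config handler, and -cfg applies typed key=value overrides onto cfg. Hypothetical usage (the model name 'tsdf-multiwoz' is an assumption):

#   python model.py -mode train -model tsdf-multiwoz
#   python model.py -mode test  -model tsdf-multiwoz -cfg beam_search=True
# or programmatically, e.g. to obtain a loaded model object:
m = main(arg_mode='load', arg_model='tsdf-multiwoz')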
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/reader.html b/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/reader.html
new file mode 100644
index 0000000..cdf1afb
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/reader.html
@@ -0,0 +1,1307 @@
+convlab.modules.e2e.multiwoz.Sequicity.reader — ConvLab 0.1 documentation
Source code for convlab.modules.e2e.multiwoz.Sequicity.reader

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import csv
+import json
+import logging
+import os
+import pickle
+import random
+import re
+
+import numpy as np
+from nltk.stem import WordNetLemmatizer
+from nltk.tokenize import word_tokenize
+
+from convlab.modules.e2e.multiwoz.Sequicity.config import global_config as cfg
+
+
+
[docs]def clean_replace(s, r, t, forward=True, backward=False): + def clean_replace_single(s, r, t, forward, backward, sidx=0): + idx = s[sidx:].find(r) + if idx == -1: + return s, -1 + idx += sidx + idx_r = idx + len(r) + if backward: + while idx > 0 and s[idx - 1]: + idx -= 1 + elif idx > 0 and s[idx - 1] != ' ': + return s, -1 + + if forward: + while idx_r < len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()): + idx_r += 1 + elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()): + return s, -1 + return s[:idx] + t + s[idx_r:], idx_r + + sidx = 0 + while sidx != -1: + s, sidx = clean_replace_single(s, r, t, forward, backward, sidx) + return s
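clean_replace() is a word-boundary-aware replace used for delexicalization: with forward=True the matched span is first extended over any trailing alphanumerics (so inflected forms are swallowed too), while forward=False rejects such matches outright. Worked examples, traced from the logic above:

clean_replace('cheap restaurants here', 'restaurant', 'name_SLOT')
# -> 'cheap name_SLOT here'   (forward=True swallows the trailing 's')
clean_replace('cheap restaurants here', 'restaurant', 'name_SLOT', forward=False)
# -> 'cheap restaurants here' (unchanged: the match is followed by a letter)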
+ + +class _ReaderBase: + class LabelSet: + def __init__(self): + self._idx2item = {} + self._item2idx = {} + self._freq_dict = {} + + def __len__(self): + return len(self._idx2item) + + def _absolute_add_item(self, item): + idx = len(self) + self._idx2item[idx] = item + self._item2idx[item] = idx + + def add_item(self, item): + if item not in self._freq_dict: + self._freq_dict[item] = 0 + self._freq_dict[item] += 1 + + def construct(self, limit): + l = sorted(self._freq_dict.keys(), key=lambda x: -self._freq_dict[x]) + print('Actual label size %d' % (len(l) + len(self._idx2item))) + if len(l) + len(self._idx2item) < limit: + logging.warning('actual label set smaller than that configured: {}/{}' + .format(len(l) + len(self._idx2item), limit)) + for item in l: + if item not in self._item2idx: + idx = len(self._idx2item) + self._idx2item[idx] = item + self._item2idx[item] = idx + if len(self._idx2item) >= limit: + break + + def encode(self, item): + return self._item2idx[item] + + def decode(self, idx): + return self._idx2item[idx] + + class Vocab(LabelSet): + def __init__(self, init=True): + _ReaderBase.LabelSet.__init__(self) + if init: + self._absolute_add_item('<pad>') # 0 + self._absolute_add_item('<go>') # 1 + self._absolute_add_item('<unk>') # 2 + self._absolute_add_item('<go2>') # 3 + + def load_vocab(self, vocab_path): + f = open(vocab_path, 'rb') + dic = pickle.load(f) + self._idx2item = dic['idx2item'] + self._item2idx = dic['item2idx'] + self._freq_dict = dic['freq_dict'] + f.close() + + def save_vocab(self, vocab_path): + f = open(vocab_path, 'wb') + dic = { + 'idx2item': self._idx2item, + 'item2idx': self._item2idx, + 'freq_dict': self._freq_dict + } + pickle.dump(dic, f) + f.close() + + def sentence_encode(self, word_list): + return [self.encode(_) for _ in word_list] + + def sentence_decode(self, index_list, eos=None): + l = [self.decode(_) for _ in index_list] + if not eos or eos not in l: + return ' '.join(l) + else: + idx = l.index(eos) + return ' '.join(l[:idx]) + + def nl_decode(self, l, eos=None): + return [self.sentence_decode(_, eos) + '\n' for _ in l] + + def encode(self, item): + if item in self._item2idx: + return self._item2idx[item] + else: + return self._item2idx['<unk>'] + + def decode(self, idx): + idx = np.int(idx) + if idx < len(self): + return self._idx2item[idx] + else: + return 'ITEM_%d' % (idx - cfg.vocab_size) + + def __init__(self): + self.train, self.dev, self.test = [], [], [] + self.vocab = self.Vocab() + self.result_file = '' + + def _construct(self, *args): + """ + load data, construct vocab and store them in self.train/dev/test + :param args: + :return: + """ + raise NotImplementedError('This is an abstract class, bro') + + def _bucket_by_turn(self, encoded_data): + turn_bucket = {} + for dial in encoded_data: + turn_len = len(dial) + if turn_len not in turn_bucket: + turn_bucket[turn_len] = [] + turn_bucket[turn_len].append(dial) + del_l = [] + for k in turn_bucket: + if k >= 5: del_l.append(k) + logging.debug("bucket %d instance %d" % (k, len(turn_bucket[k]))) + # for k in del_l: + # turn_bucket.pop(k) + return turn_bucket + + def _mark_batch_as_supervised(self, all_batches): + supervised_num = int(len(all_batches) * cfg.spv_proportion / 100) + for i, batch in enumerate(all_batches): + for dial in batch: + for turn in dial: + turn['supervised'] = i < supervised_num + if not turn['supervised']: + turn['degree'] = [0.] * cfg.degree_size # unsupervised learning. 
DB degree should be unknown + return all_batches + + def _construct_mini_batch(self, data): + all_batches = [] + batch = [] + for dial in data: + batch.append(dial) + if len(batch) == cfg.batch_size: + all_batches.append(batch) + batch = [] + # if remainder > 1/2 batch_size, just put them in the previous batch, otherwise form a new batch + if len(batch) > 0.5 * cfg.batch_size: + all_batches.append(batch) + elif len(all_batches): + all_batches[-1].extend(batch) + else: + all_batches.append(batch) + return all_batches + + def _transpose_batch(self, batch): + dial_batch = [] + turn_num = len(batch[0]) + for turn in range(turn_num): + turn_l = {} + for dial in batch: + this_turn = dial[turn] + for k in this_turn: + if k not in turn_l: + turn_l[k] = [] + turn_l[k].append(this_turn[k]) + dial_batch.append(turn_l) + return dial_batch + + def mini_batch_iterator(self, set_name): + name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} + dial = name_to_set[set_name] + turn_bucket = self._bucket_by_turn(dial) + # self._shuffle_turn_bucket(turn_bucket) + all_batches = [] + for k in turn_bucket: + batches = self._construct_mini_batch(turn_bucket[k]) + all_batches += batches + self._mark_batch_as_supervised(all_batches) + random.shuffle(all_batches) + for i, batch in enumerate(all_batches): + yield self._transpose_batch(batch) + + def wrap_result(self, turn_batch, gen_m, gen_z, eos_syntax=None, prev_z=None): + """ + wrap generated results + :param gen_z: + :param gen_m: + :param turn_batch: dict of [i_1,i_2,...,i_b] with keys + :return: + """ + + results = [] + if eos_syntax is None: + eos_syntax = {'response': 'EOS_M', 'user': 'EOS_U', 'bspan': 'EOS_Z2'} + batch_size = len(turn_batch['user']) + for i in range(batch_size): + entry = {} + if prev_z is not None: + src = prev_z[i] + turn_batch['user'][i] + else: + src = turn_batch['user'][i] + for key in turn_batch: + entry[key] = turn_batch[key][i] + if key in eos_syntax: + entry[key] = self.vocab.sentence_decode(entry[key], eos=eos_syntax[key]) + if gen_m: + entry['generated_response'] = self.vocab.sentence_decode(gen_m[i], eos='EOS_M') + else: + entry['generated_response'] = '' + if gen_z: + entry['generated_bspan'] = self.vocab.sentence_decode(gen_z[i], eos='EOS_Z2') + else: + entry['generated_bspan'] = '' + results.append(entry) + write_header = False + if not self.result_file: + self.result_file = open(cfg.result_path, 'w') + self.result_file.write(str(cfg)) + write_header = True + + field = ['dial_id', 'turn_num', 'user', 'generated_bspan', 'bspan', 'generated_response', 'response', 'u_len', + 'm_len', 'supervised'] + for result in results: + del_k = [] + for k in result: + if k not in field: + del_k.append(k) + for k in del_k: + result.pop(k) + writer = csv.DictWriter(self.result_file, fieldnames=field) + if write_header: + self.result_file.write('START_CSV_SECTION\n') + writer.writeheader() + writer.writerows(results) + return results + + def db_search(self, constraints): + raise NotImplementedError('This is an abstract method') + + def db_degree_handler(self, z_samples, *args, **kwargs): + """ + returns degree of database searching and it may be used to control further decoding. 
+ One hot vector, indicating the number of entries found: [0, 1, 2, 3, 4, >=5] + :param z_samples: nested list of B * [T] + :return: an one-hot control *numpy* control vector + """ + control_vec = [] + + for cons_idx_list in z_samples: + constraints = set() + for cons in cons_idx_list: + if not isinstance(cons, str): + cons = self.vocab.decode(cons) + if cons == 'EOS_Z1': + break + constraints.add(cons) + match_result = self.db_search(constraints) + degree = len(match_result) + # modified + # degree = 0 + control_vec.append(self._degree_vec_mapping(degree)) + return np.array(control_vec) + + def _degree_vec_mapping(self, match_num): + l = [0.] * cfg.degree_size + l[min(cfg.degree_size - 1, match_num)] = 1. + return l + + +
[docs]class CamRest676Reader(_ReaderBase): + def __init__(self): + super().__init__() + self._construct(cfg.data, cfg.db) + self.result_file = '' + + def _get_tokenized_data(self, raw_data, db_data, construct_vocab): + tokenized_data = [] + vk_map = self._value_key_map(db_data) + for dial_id, dial in enumerate(raw_data): + tokenized_dial = [] + for turn in dial['dial']: + turn_num = turn['turn'] + constraint = [] + requested = [] + for slot in turn['usr']['slu']: + if slot['act'] == 'inform': + s = slot['slots'][0][1] + if s not in ['dontcare', 'none']: + constraint.extend(word_tokenize(s)) + else: + requested.extend(word_tokenize(slot['slots'][0][1])) + degree = len(self.db_search(constraint)) + requested = sorted(requested) + constraint.append('EOS_Z1') + requested.append('EOS_Z2') + user = word_tokenize(turn['usr']['transcript']) + ['EOS_U'] + response = word_tokenize(self._replace_entity(turn['sys']['sent'], vk_map, constraint)) + ['EOS_M'] + tokenized_dial.append({ + 'dial_id': dial_id, + 'turn_num': turn_num, + 'user': user, + 'response': response, + 'constraint': constraint, + 'requested': requested, + 'degree': degree, + }) + if construct_vocab: + for word in user + response + constraint + requested: + self.vocab.add_item(word) + tokenized_data.append(tokenized_dial) + return tokenized_data + + def _replace_entity(self, response, vk_map, constraint): + response = re.sub('[cC][., ]*[bB][., ]*\d[., ]*\d[., ]*\w[., ]*\w', 'postcode_SLOT', response) + response = re.sub('\d{5}\s?\d{6}', 'phone_SLOT', response) + constraint_str = ' '.join(constraint) + for v, k in sorted(vk_map.items(), key=lambda x: -len(x[0])): + start_idx = response.find(v) + if start_idx == -1 \ + or (start_idx != 0 and response[start_idx - 1] != ' ') \ + or (v in constraint_str): + continue + if k not in ['name', 'address']: + response = clean_replace(response, v, k + '_SLOT', forward=True, backward=False) + else: + response = clean_replace(response, v, k + '_SLOT', forward=False, backward=False) + return response + + def _value_key_map(self, db_data): + requestable_keys = ['address', 'name', 'phone', 'postcode', 'food', 'area', 'pricerange'] + value_key = {} + for db_entry in db_data: + for k, v in db_entry.items(): + if k in requestable_keys: + value_key[v] = k + return value_key + + def _get_encoded_data(self, tokenized_data): + encoded_data = [] + for dial in tokenized_data: + encoded_dial = [] + prev_response = [] + for turn in dial: + user = self.vocab.sentence_encode(turn['user']) + response = self.vocab.sentence_encode(turn['response']) + constraint = self.vocab.sentence_encode(turn['constraint']) + requested = self.vocab.sentence_encode(turn['requested']) + degree = self._degree_vec_mapping(turn['degree']) + turn_num = turn['turn_num'] + dial_id = turn['dial_id'] + + # final input + encoded_dial.append({ + 'dial_id': dial_id, + 'turn_num': turn_num, + 'user': prev_response + user, + 'response': response, + 'bspan': constraint + requested, + 'u_len': len(prev_response + user), + 'm_len': len(response), + 'degree': degree, + }) + # modified + prev_response = response + encoded_data.append(encoded_dial) + return encoded_data + + def _split_data(self, encoded_data, split): + """ + split data into train/dev/test + :param encoded_data: list + :param split: tuple / list + :return: + """ + total = sum(split) + dev_thr = len(encoded_data) * split[0] // total + test_thr = len(encoded_data) * (split[0] + split[1]) // total + train, dev, test = encoded_data[:dev_thr], encoded_data[dev_thr:test_thr], 
encoded_data[test_thr:] + return train, dev, test + + def _construct(self, data_json_path, db_json_path): + """ + construct encoded train, dev, test set. + :param data_json_path: + :param db_json_path: + :return: + """ + construct_vocab = False + if not os.path.isfile(cfg.vocab_path): + construct_vocab = True + print('Constructing vocab file...') + raw_data_json = open(data_json_path) + raw_data = json.loads(raw_data_json.read().lower()) + db_json = open(db_json_path) + db_data = json.loads(db_json.read().lower()) + self.db = db_data + tokenized_data = self._get_tokenized_data(raw_data, db_data, construct_vocab) + if construct_vocab: + self.vocab.construct(cfg.vocab_size) + self.vocab.save_vocab(cfg.vocab_path) + else: + self.vocab.load_vocab(cfg.vocab_path) + encoded_data = self._get_encoded_data(tokenized_data) + self.train, self.dev, self.test = self._split_data(encoded_data, cfg.split) + random.shuffle(self.train) + random.shuffle(self.dev) + random.shuffle(self.test) + raw_data_json.close() + db_json.close() + +
+ + +
[docs]class KvretReader(_ReaderBase): + def __init__(self): + super().__init__() + + self.entity_dict = {} + self.abbr_dict = {} + + self.wn = WordNetLemmatizer() + self.db = {} + + self.tokenized_data_path = './data/kvret/' + self._construct(cfg.train, cfg.dev, cfg.test, cfg.entity) + + def _construct(self, train_json_path, dev_json_path, test_json_path, entity_json_path): + construct_vocab = False + if not os.path.isfile(cfg.vocab_path): + construct_vocab = True + print('Constructing vocab file...') + train_json, dev_json, test_json = open(train_json_path), open(dev_json_path), open(test_json_path) + entity_json = open(entity_json_path) + train_data, dev_data, test_data = json.loads(train_json.read().lower()), json.loads(dev_json.read().lower()), \ + json.loads(test_json.read().lower()) + entity_data = json.loads(entity_json.read().lower()) + self._get_entity_dict(entity_data) + + tokenized_train = self._get_tokenized_data(train_data, construct_vocab, 'train') + tokenized_dev = self._get_tokenized_data(dev_data, construct_vocab, 'dev') + tokenized_test = self._get_tokenized_data(test_data, construct_vocab, 'test') + + if construct_vocab: + self.vocab.construct(cfg.vocab_size) + self.vocab.save_vocab(cfg.vocab_path) + else: + self.vocab.load_vocab(cfg.vocab_path) + + self.train, self.dev, self.test = map(self._get_encoded_data, [tokenized_train, tokenized_dev, + tokenized_test]) + random.shuffle(self.train) + random.shuffle(self.dev) + random.shuffle(self.test) + + train_json.close() + dev_json.close() + test_json.close() + entity_json.close() + + def _save_tokenized_data(self, data, filename): + path = self.tokenized_data_path + filename + '.tokenized.json' + f = open(path,'w') + json.dump(data,f,indent=2) + f.close() + + def _load_tokenized_data(self, filename): + ''' + path = self.tokenized_data_path + filename + '.tokenized.json' + try: + f = open(path,'r') + except FileNotFoundError: + return None + data = json.load(f) + f.close() + return data + ''' + return None + + def _tokenize(self, sent): + return ' '.join(word_tokenize(sent)) + + def _lemmatize(self, sent): + return ' '.join([self.wn.lemmatize(_) for _ in sent.split()]) + + def _replace_entity(self, response, vk_map, prev_user_input, intent): + response = re.sub('\d+-?\d*fs?', 'temperature_SLOT', response) + response = re.sub('\d+\s?miles?', 'distance_SLOT', response) + response = re.sub('\d+\s\w+\s(dr)?(ct)?(rd)?(road)?(st)?(ave)?(way)?(pl)?\w*[.]?', 'address_SLOT', response) + response = self._lemmatize(self._tokenize(response)) + requestable = { + 'weather': ['weather_attribute'], + 'navigate': ['poi', 'traffic_info', 'address', 'distance'], + 'schedule': ['event', 'date', 'time', 'party', 'agenda', 'room'] + } + reqs = set() + for v, k in sorted(vk_map.items(), key=lambda x: -len(x[0])): + start_idx = response.find(v) + if start_idx == -1 or k not in requestable[intent]: + continue + end_idx = start_idx + len(v) + while end_idx < len(response) and response[end_idx] != ' ': + end_idx += 1 + # test whether they are indeed the same word + lm1, lm2 = v.replace('.', '').replace(' ', '').replace("'", ''), \ + response[start_idx:end_idx].replace('.', '').replace(' ', '').replace("'", '') + if lm1 == lm2 and lm1 not in prev_user_input and v not in prev_user_input: + response = clean_replace(response, response[start_idx:end_idx], k + '_SLOT') + reqs.add(k) + return response,reqs + + def _clean_constraint_dict(self, constraint_dict, intent, prefer='short'): + """ + clean the constraint dict so that every key is in "informable" and 
similar to one in provided entity dict. + :param constraint_dict: + :return: + """ + informable = { + 'weather': ['date', 'location', 'weather_attribute'], + 'navigate': ['poi_type', 'distance'], + 'schedule': ['event', 'date', 'time', 'agenda', 'party', 'room'] + } + + del_key = set(constraint_dict.keys()).difference(informable[intent]) + for key in del_key: + constraint_dict.pop(key) + invalid_key = [] + for k in constraint_dict: + constraint_dict[k] = constraint_dict[k].strip() + v = self._lemmatize(self._tokenize(constraint_dict[k])) + v = re.sub('(\d+) ([ap]m)', lambda x: x.group(1) + x.group(2), v) + v = re.sub('(\d+)\s?(mile)s?', lambda x: x.group(1) + ' ' + x.group(2), v) + if v in self.entity_dict: + if prefer == 'short': + constraint_dict[k] = v + elif prefer == 'long': + constraint_dict[k] = self.abbr_dict.get(v, v) + elif v.split()[0] in self.entity_dict: + if prefer == 'short': + constraint_dict[k] = v.split()[0] + elif prefer == 'long': + constraint_dict[k] = self.abbr_dict.get(v.split()[0], v) + else: + invalid_key.append(k) + for key in invalid_key: + constraint_dict.pop(key) + return constraint_dict + + def _get_tokenized_data(self, raw_data, add_to_vocab, data_type, is_test=False): + """ + Somerrthing to note: We define requestable and informable slots as below in further experiments + (including other baselines): + + informable = { + 'weather': ['date','location','weather_attribute'], + 'navigate': ['poi_type','distance'], + 'schedule': ['event'] + } + + requestable = { + 'weather': ['weather_attribute'], + 'navigate': ['poi','traffic','address','distance'], + 'schedule': ['event','date','time','party','agenda','room'] + } + :param raw_data: + :param add_to_vocab: + :param data_type: + :return: + """ + tokenized_data = self._load_tokenized_data(data_type) + if tokenized_data is not None: + logging.info('directly loading %s' % data_type) + return tokenized_data + tokenized_data = [] + state_dump = {} + for dial_id, raw_dial in enumerate(raw_data): + tokenized_dial = [] + prev_utter = '' + single_turn = {} + constraint_dict = {} + intent = raw_dial['scenario']['task']['intent'] + if cfg.intent != 'all' and cfg.intent != intent: + if intent not in ['navigate', 'weather', 'schedule']: + raise ValueError('what is %s intent bro?' 
% intent) + else: + continue + prev_response = [] + for turn_num, dial_turn in enumerate(raw_dial['dialogue']): + state_dump[(dial_id, turn_num)] = {} + if dial_turn['turn'] == 'driver': + u = self._lemmatize(self._tokenize(dial_turn['data']['utterance'])) + u = re.sub('(\d+) ([ap]m)', lambda x: x.group(1) + x.group(2), u) + single_turn['user'] = prev_response + u.split() + ['EOS_U'] + prev_utter += u + elif dial_turn['turn'] == 'assistant': + s = dial_turn['data']['utterance'] + # find entities and replace them + s = re.sub('(\d+) ([ap]m)', lambda x: x.group(1) + x.group(2), s) + s, reqs = self._replace_entity(s, self.entity_dict, prev_utter, intent) + single_turn['response'] = s.split() + ['EOS_M'] + # get constraints + if not constraint_dict: + constraint_dict = dial_turn['data']['slots'] + else: + for k, v in dial_turn['data']['slots'].items(): + constraint_dict[k] = v + constraint_dict = self._clean_constraint_dict(constraint_dict, intent) + + raw_constraints = constraint_dict.values() + raw_constraints = [self._lemmatize(self._tokenize(_)) for _ in raw_constraints] + + # add separator + constraints = [] + for item in raw_constraints: + if constraints: + constraints.append(';') + constraints.extend(item.split()) + # get requests + dataset_requested = set( + filter(lambda x: dial_turn['data']['requested'][x], dial_turn['data']['requested'].keys())) + requestable = { + 'weather': ['weather_attribute'], + 'navigate': ['poi', 'traffic', 'address', 'distance'], + 'schedule': ['date', 'time', 'party', 'agenda', 'room'] + } + requests = sorted(list(dataset_requested.intersection(reqs))) + + single_turn['constraint'] = constraints + ['EOS_Z1'] + single_turn['requested'] = requests + ['EOS_Z2'] + single_turn['turn_num'] = len(tokenized_dial) + single_turn['dial_id'] = dial_id + single_turn['degree'] = self.db_degree(constraints, raw_dial['scenario']['kb']['items']) + self.db[dial_id] = raw_dial['scenario']['kb']['items'] + if 'user' in single_turn: + state_dump[(dial_id, len(tokenized_dial))]['constraint'] = constraint_dict + state_dump[(dial_id, len(tokenized_dial))]['request'] = requests + tokenized_dial.append(single_turn) + prev_response = single_turn['response'] + single_turn = {} + if add_to_vocab: + for single_turn in tokenized_dial: + for word_token in single_turn['constraint'] + single_turn['requested'] + \ + single_turn['user'] + single_turn['response']: + self.vocab.add_item(word_token) + tokenized_data.append(tokenized_dial) + self._save_tokenized_data(tokenized_data, data_type) + return tokenized_data + + def _get_encoded_data(self, tokenized_data): + encoded_data = [] + for dial in tokenized_data: + new_dial = [] + for turn in dial: + turn['constraint'] = self.vocab.sentence_encode(turn['constraint']) + turn['requested'] = self.vocab.sentence_encode(turn['requested']) + turn['bspan'] = turn['constraint'] + turn['requested'] + turn['user'] = self.vocab.sentence_encode(turn['user']) + turn['response'] = self.vocab.sentence_encode(turn['response']) + turn['u_len'] = len(turn['user']) + turn['m_len'] = len(turn['response']) + turn['degree'] = self._degree_vec_mapping(turn['degree']) + new_dial.append(turn) + encoded_data.append(new_dial) + return encoded_data + + def _get_entity_dict(self, entity_data): + entity_dict = {} + for k in entity_data: + if isinstance(entity_data[k][0], str): + for entity in entity_data[k]: + entity = self._lemmatize(self._tokenize(entity)) + entity_dict[entity] = k + if k in ['event', 'poi_type']: + entity_dict[entity.split()[0]] = k + 
self.abbr_dict[entity.split()[0]] = entity + elif isinstance(entity_data[k][0], dict): + for entity_entry in entity_data[k]: + for entity_type, entity in entity_entry.items(): + entity_type = 'poi_type' if entity_type == 'type' else entity_type + entity = self._lemmatize(self._tokenize(entity)) + entity_dict[entity] = entity_type + if entity_type in ['event', 'poi_type']: + entity_dict[entity.split()[0]] = entity_type + self.abbr_dict[entity.split()[0]] = entity + self.entity_dict = entity_dict + +
[docs] def db_degree(self, constraints, items): + cnt = 0 + if items is not None: + for item in items: + item = item.values() + flg = True + for c in constraints: + itemvaluestr = " ".join(list(item)) + if c not in itemvaluestr: + flg = False + break + if flg: + cnt += 1 + return cnt
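db_degree() counts the KB rows whose concatenated values contain every constraint token as a substring. A toy example (the items and the reader instance are hypothetical):

items = [
    {'poi': 'home', 'distance': '2 miles'},
    {'poi': 'pizza place', 'distance': '4 miles'},
]
reader.db_degree(['miles'], items)       # -> 2 (both rows match)
reader.db_degree(['pizza', '4'], items)  # -> 1 (only the second row)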
+ +
[docs] def db_degree_handler(self, z_samples, idx=None, *args, **kwargs): + control_vec = [] + for i,cons_idx_list in enumerate(z_samples): + constraints = set() + for cons in cons_idx_list: + if not isinstance(cons, str): + cons = self.vocab.decode(cons) + if cons == 'EOS_Z1': + break + constraints.add(cons) + items = self.db[idx[i]] + degree = self.db_degree(constraints, items) + control_vec.append(self._degree_vec_mapping(degree)) + return np.array(control_vec)
+ +
[docs]class MultiWozReader(_ReaderBase): + def __init__(self): + super().__init__() + self._construct(cfg.train, cfg.dev, cfg.test, cfg.db) + self.result_file = '' + + def _get_tokenized_data(self, raw_data, db_data, construct_vocab): + requestable_keys = ['addr', 'area', 'fee', 'name', 'phone', 'post', 'price', 'type', 'department', 'internet', 'parking', 'stars', 'food', 'arrive', 'day', 'depart', 'dest', 'leave', 'ticket', 'id'] + + tokenized_data = [] + vk_map = self._value_key_map(db_data) + for dial_id, dial in enumerate(raw_data): + tokenized_dial = [] + for turn in dial['dial']: + turn_num = turn['turn'] + constraint = [] + requested = [] + for slot_act in turn['usr']['slu']: + if slot_act == 'inform': + slot_values = turn['usr']['slu'][slot_act] + for v in slot_values: + s = v[1] + if s not in ['dont_care', 'none']: + constraint.append(s) + elif slot_act == 'request': + slot_values = turn['usr']['slu'][slot_act] + for v in slot_values: + s = v[0] + if s in requestable_keys: + requested.append(s) + degree = len(self.db_search(constraint)) + requested = sorted(requested) + constraint.append('EOS_Z1') + requested.append('EOS_Z2') + user = turn['usr']['transcript'].split() + ['EOS_U'] + response = self._replace_entity(turn['sys']['sent'], vk_map, constraint).split() + ['EOS_M'] + response_origin = turn['sys']['sent'].split() + tokenized_dial.append({ + 'dial_id': dial_id, + 'turn_num': turn_num, + 'user': user, + 'response': response, + 'response_origin': response_origin, + 'constraint': constraint, + 'requested': requested, + 'degree': degree, + }) + if construct_vocab: + for word in user + response + constraint + requested: + self.vocab.add_item(word) + tokenized_data.append(tokenized_dial) + return tokenized_data + + def _replace_entity(self, response, vk_map, constraint): + response = re.sub('[cC][., ]*[bB][., ]*\d[., ]*\d[., ]*\w[., ]*\w', 'postcode_SLOT', response) + response = re.sub('\d{5}\s?\d{6}', 'phone_SLOT', response) + constraint_str = ' '.join(constraint) + for v, k in sorted(vk_map.items(), key=lambda x: -len(x[0])): + start_idx = response.find(v) + if start_idx == -1 \ + or (start_idx != 0 and response[start_idx - 1] != ' ') \ + or (v in constraint_str): + continue + response = clean_replace(response, v, k + '_SLOT') + return response + + def _value_key_map(self, db_data): + def normal(string): + string = string.lower() + string = re.sub(r'\s*-\s*', '', string) + string = re.sub(r' ', '_', string) + string = re.sub(r',', '_,', string) + string = re.sub(r'\'', '_', string) + string = re.sub(r'\.', '_.', string) + string = re.sub(r'_+', '_', string) + string = re.sub(r'children', 'child_-s', string) + return string + requestable_dict = {'address':'addr', + 'area':'area', + 'entrance fee':'fee', + 'name':'name', + 'phone':'phone', + 'postcode':'post', + 'pricerange':'price', + 'type':'type', + 'department':'department', + 'internet':'internet', + 'parking':'parking', + 'stars':'stars', + 'food':'food', + 'arriveBy':'arrive', + 'day':'day', + 'departure':'depart', + 'destination':'dest', + 'leaveAt':'leave', + 'price':'ticket', + 'trainId':'id'} + value_key = {} + for db_entry in db_data: + for k, v in db_entry.items(): + if k in requestable_dict: + value_key[normal(v)] = requestable_dict[k] + return value_key + + def _get_encoded_data(self, tokenized_data): + encoded_data = [] + for dial in tokenized_data: + encoded_dial = [] + prev_response = [] + for turn in dial: + user = self.vocab.sentence_encode(turn['user']) + response = 
self.vocab.sentence_encode(turn['response']) + response_origin = ' '.join(turn['response_origin']) + constraint = self.vocab.sentence_encode(turn['constraint']) + requested = self.vocab.sentence_encode(turn['requested']) + degree = self._degree_vec_mapping(turn['degree']) + turn_num = turn['turn_num'] + dial_id = turn['dial_id'] + + # final input + encoded_dial.append({ + 'dial_id': dial_id, + 'turn_num': turn_num, + 'user': prev_response + user, + 'response': response, + 'response_origin': response_origin, + 'bspan': constraint + requested, + 'u_len': len(prev_response + user), + 'm_len': len(response), + 'degree': degree, + }) + # modified + prev_response = response + encoded_data.append(encoded_dial) + return encoded_data + + def _get_clean_db(self, raw_db_data): + for entry in raw_db_data: + for k, v in list(entry.items()): + if not isinstance(v, str) or v == '?': + entry.pop(k) + + def _construct(self, train_json_path, dev_json_path, test_json_path, db_json_path): + """ + construct encoded train, dev, test set. + :param train_json_path: + :param dev_json_path: + :param test_json_path: + :param db_json_path: list + :return: + """ + construct_vocab = False + if not os.path.isfile(cfg.vocab_path): + construct_vocab = True + print('Constructing vocab file...') + with open(train_json_path) as f: + train_raw_data = json.loads(f.read().lower()) + with open(dev_json_path) as f: + dev_raw_data = json.loads(f.read().lower()) + with open(test_json_path) as f: + test_raw_data = json.loads(f.read().lower()) + db_data = list() + for domain_db_json_path in db_json_path: + with open(domain_db_json_path) as f: + db_data += json.loads(f.read().lower()) + self._get_clean_db(db_data) + self.db = db_data + + train_tokenized_data = self._get_tokenized_data(train_raw_data, db_data, construct_vocab) + dev_tokenized_data = self._get_tokenized_data(dev_raw_data, db_data, construct_vocab) + test_tokenized_data = self._get_tokenized_data(test_raw_data, db_data, construct_vocab) + if construct_vocab: + self.vocab.construct(cfg.vocab_size) + self.vocab.save_vocab(cfg.vocab_path) + else: + self.vocab.load_vocab(cfg.vocab_path) + self.train = self._get_encoded_data(train_tokenized_data) + self.dev = self._get_encoded_data(dev_tokenized_data) + self.test = self._get_encoded_data(test_tokenized_data) + random.shuffle(self.train) + random.shuffle(self.dev) + random.shuffle(self.test) + + + +
[docs] def wrap_result(self, turn_batch, gen_m, gen_z, eos_syntax=None, prev_z=None): + """ + wrap generated results + :param gen_z: + :param gen_m: + :param turn_batch: dict of [i_1,i_2,...,i_b] with keys + :return: + """ + + results = [] + if eos_syntax is None: + eos_syntax = {'response': 'EOS_M', 'user': 'EOS_U', 'bspan': 'EOS_Z2'} + batch_size = len(turn_batch['user']) + for i in range(batch_size): + entry = {} + if prev_z is not None: + src = prev_z[i] + turn_batch['user'][i] + else: + src = turn_batch['user'][i] + for key in turn_batch: + entry[key] = turn_batch[key][i] + if key in eos_syntax: + entry[key] = self.vocab.sentence_decode(entry[key], eos=eos_syntax[key]) + if gen_z: + entry['generated_bspan'] = self.vocab.sentence_decode(gen_z[i], eos='EOS_Z2') + else: + entry['generated_bspan'] = '' + if gen_m: + entry['generated_response'] = self.vocab.sentence_decode(gen_m[i], eos='EOS_M') + constraint_request = entry['generated_bspan'].split() + constraints = constraint_request[:constraint_request.index('EOS_Z1')] if 'EOS_Z1' \ + in constraint_request else constraint_request + for j, ent in enumerate(constraints): + constraints[j] = ent.replace('_', ' ') + degree = self.db_search(constraints) + #print('constraints',constraints) + #print('degree',degree) + venue = random.sample(degree, 1)[0] if degree else dict() + l = [self.vocab.decode(_) for _ in gen_m[i]] + if 'EOS_M' in l: + l = l[:l.index('EOS_M')] + l_origin = [] + for word in l: + if 'SLOT' in word: + word = word[:-5] + if word in venue.keys(): + value = venue[word] + if value != '?': + l_origin.append(value.replace(' ', '_')) + else: + l_origin.append(word) + entry['generated_response_origin'] = ' '.join(l_origin) + else: + entry['generated_response'] = '' + entry['generated_response_origin'] = '' + results.append(entry) + write_header = False + if not self.result_file: + self.result_file = open(cfg.result_path, 'w') + self.result_file.write(str(cfg)) + write_header = True + + field = ['dial_id', 'turn_num', 'user', 'generated_bspan', 'bspan', 'generated_response', 'response', 'u_len', + 'm_len', 'supervised', 'generated_response_origin', 'response_origin'] + for result in results: + del_k = [] + for k in result: + if k not in field: + del_k.append(k) + for k in del_k: + result.pop(k) + writer = csv.DictWriter(self.result_file, fieldnames=field) + if write_header: + self.result_file.write('START_CSV_SECTION\n') + writer.writeheader() + writer.writerows(results) + return results
+ +
[docs]def pad_sequences(sequences, maxlen=None, dtype='int32', + padding='pre', truncating='pre', value=0.): + if not hasattr(sequences, '__len__'): + raise ValueError('`sequences` must be iterable.') + lengths = [] + for x in sequences: + if not hasattr(x, '__len__'): + raise ValueError('`sequences` must be a list of iterables. ' + 'Found non-iterable: ' + str(x)) + lengths.append(len(x)) + + num_samples = len(sequences) + seq_maxlen = np.max(lengths) + if maxlen is not None and cfg.truncated: + maxlen = min(seq_maxlen, maxlen) + else: + maxlen = seq_maxlen + # take the sample shape from the first non empty sequence + # checking for consistency in the main loop below. + sample_shape = tuple() + for s in sequences: + if len(s) > 0: + sample_shape = np.asarray(s).shape[1:] + break + + x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype) + for idx, s in enumerate(sequences): + if not len(s): + continue # empty list/array was found + if truncating == 'pre': + trunc = s[-maxlen:] + elif truncating == 'post': + trunc = s[:maxlen] + else: + raise ValueError('Truncating type "%s" not understood' % truncating) + + # check `trunc` has expected shape + trunc = np.asarray(trunc, dtype=dtype) + if trunc.shape[1:] != sample_shape: + raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' % + (trunc.shape[1:], idx, sample_shape)) + + if padding == 'post': + x[idx, :len(trunc)] = trunc + elif padding == 'pre': + x[idx, -len(trunc):] = trunc + else: + raise ValueError('Padding type "%s" not understood' % padding) + return x
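pad_sequences() is essentially the Keras utility of the same name: without an explicit maxlen (or when cfg.truncated is off) every sequence is padded to the batch maximum. A small sanity check:

import numpy as np
pad_sequences([[1, 2], [3, 4, 5]], padding='post')
# -> array([[1, 2, 0],
#           [3, 4, 5]], dtype=int32)
# The default padding='pre' left-pads instead, and truncating='pre' keeps
# the tail of over-long sequences.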
+ + +
[docs]def get_glove_matrix(vocab, initial_embedding_np): + """ + return a glove embedding matrix + :param self: + :param glove_file: + :param initial_embedding_np: + :return: np array of [V,E] + """ + ef = open(cfg.glove_path, 'r') + cnt = 0 + vec_array = initial_embedding_np + old_avg = np.average(vec_array) + old_std = np.std(vec_array) + vec_array = vec_array.astype(np.float32) + new_avg, new_std = 0, 0 + + for line in ef.readlines(): + line = line.strip().split(' ') + word, vec = line[0], line[1:] + vec = np.array(vec, np.float32) + word_idx = vocab.encode(word) + if word.lower() in ['unk', '<unk>'] or word_idx != vocab.encode('<unk>'): + cnt += 1 + vec_array[word_idx] = vec + new_avg += np.average(vec) + new_std += np.std(vec) + new_avg /= cnt + new_std /= cnt + ef.close() + logging.info('%d known embedding. old mean: %f new mean %f, old std %f new std %f' % (cnt, old_avg, + new_avg, old_std, new_std)) + return vec_array
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/tsd_net.html b/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/tsd_net.html
new file mode 100644
index 0000000..0d51894
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/e2e/multiwoz/Sequicity/tsd_net.html
@@ -0,0 +1,875 @@
+convlab.modules.e2e.multiwoz.Sequicity.tsd_net — ConvLab 0.1 documentation

Source code for convlab.modules.e2e.multiwoz.Sequicity.tsd_net

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import copy
+import math
+import random
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.autograd import Variable
+from torch.distributions import Categorical
+
+from convlab.modules.e2e.multiwoz.Sequicity.config import global_config as cfg
+from convlab.modules.e2e.multiwoz.Sequicity.reader import pad_sequences
+
+
+
[docs]def cuda_(var):
    return var.cuda() if cfg.cuda else var
+ + +
[docs]def toss_(p):
    # returns True with probability (p + 1)/100 for integer p in [0, 99];
    # p >= 99 makes it always True (used for the teacher-forcing rate)
    return random.randint(0, 99) <= p
+ + +
[docs]def nan(v):
    if isinstance(v, float):
        # the original `v == float('nan')` is always False (NaN != NaN)
        return math.isnan(v)
    return np.isnan(np.sum(v.data.cpu().numpy()))
+ +
[docs]def get_sparse_input_aug(x_input_np): + """ + sparse input of + :param x_input_np: [T,B] + :return: Numpy array: [B,T,aug_V] + """ + ignore_index = [0] + unk = 2 + result = np.zeros((x_input_np.shape[0], x_input_np.shape[1], cfg.vocab_size + x_input_np.shape[0]), + dtype=np.float32) + result.fill(1e-10) + for t in range(x_input_np.shape[0]): + for b in range(x_input_np.shape[1]): + w = x_input_np[t][b] + if w not in ignore_index: + if w != unk: + result[t][b][x_input_np[t][b]] = 1.0 + else: + result[t][b][cfg.vocab_size + t] = 1.0 + result_np = result.transpose((1, 0, 2)) + result = torch.from_numpy(result_np).float() + return result
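get_sparse_input_aug() builds the lookup tensor behind the copy mechanism: each source position maps to its vocabulary id, except that unk tokens (id 2) map to an augmented per-position slot vocab_size + t, so out-of-vocabulary source words remain copyable by position. A toy trace (assuming cfg.vocab_size is set):

x = np.array([[7], [2]])      # [T=2, B=1]: token 7, then <unk>
sp = get_sparse_input_aug(x)  # shape [1, 2, cfg.vocab_size + 2]
# sp[0, 0, 7]                  == 1.0  (in-vocab word keeps its own id)
# sp[0, 1, cfg.vocab_size + 1] == 1.0  (unk at t=1 -> augmented slot)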
+ +
[docs]def init_gru(gru):
    gru.reset_parameters()
    for _, hh, _, _ in gru.all_weights:
        for i in range(0, hh.size(0), gru.hidden_size):
            torch.nn.init.orthogonal(hh[i:i + gru.hidden_size], gain=1)
+ +
[docs]class Attn(nn.Module): + def __init__(self, hidden_size): + super(Attn, self).__init__() + self.hidden_size = hidden_size + self.attn = nn.Linear(self.hidden_size * 2, hidden_size) + self.v = nn.Parameter(torch.zeros(hidden_size)) + stdv = 1. / math.sqrt(self.v.size(0)) + self.v.data.normal_(mean=0, std=stdv) + +
[docs] def forward(self, hidden, encoder_outputs, normalize=True): + encoder_outputs = encoder_outputs.transpose(0, 1) # [B,T,H] + attn_energies = self.score(hidden, encoder_outputs) + normalized_energy = F.softmax(attn_energies, dim=2) # [B,1,T] + context = torch.bmm(normalized_energy, encoder_outputs) # [B,1,H] + return context.transpose(0, 1) # [1,B,H]
+ +
[docs] def score(self, hidden, encoder_outputs): + max_len = encoder_outputs.size(1) + H = hidden.repeat(max_len, 1, 1).transpose(0, 1) + energy = F.tanh(self.attn(torch.cat([H, encoder_outputs], 2))) # [B,T,2H]->[B,T,H] + energy = energy.transpose(2, 1) # [B,H,T] + v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1) # [B,1,H] + energy = torch.bmm(v, energy) # [B,1,T] + return energy
+ +
[docs]class SimpleDynamicEncoder(nn.Module): + def __init__(self, input_size, embed_size, hidden_size, n_layers, dropout): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.embed_size = embed_size + self.n_layers = n_layers + self.dropout = dropout + self.embedding = nn.Embedding(input_size, embed_size) + self.gru = nn.GRU(embed_size, hidden_size, n_layers, dropout=self.dropout, bidirectional=True) + init_gru(self.gru) + +
[docs] def forward(self, input_seqs, input_lens, hidden=None): + """ + forward procedure. No need for inputs to be sorted + :param input_seqs: Variable of [T,B] + :param hidden: + :param input_lens: *numpy array* of len for each input sequence + :return: + """ + batch_size = input_seqs.size(1) + embedded = self.embedding(input_seqs) + embedded = embedded.transpose(0, 1) # [B,T,E] + sort_idx = np.argsort(-input_lens) + unsort_idx = cuda_(torch.LongTensor(np.argsort(sort_idx))) + input_lens = input_lens[sort_idx] + sort_idx = cuda_(torch.LongTensor(sort_idx)) + embedded = embedded[sort_idx].transpose(0, 1) # [T,B,E] + packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lens) + outputs, hidden = self.gru(packed, hidden) + + outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) + outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] + outputs = outputs.transpose(0, 1)[unsort_idx].transpose(0, 1).contiguous() + hidden = hidden.transpose(0, 1)[unsort_idx].transpose(0, 1).contiguous() + return outputs, hidden, embedded
+ + +
[docs]class BSpanDecoder(nn.Module): + def __init__(self, embed_size, hidden_size, vocab_size, dropout_rate): + super().__init__() + self.gru = nn.GRU(hidden_size + embed_size, hidden_size, dropout=dropout_rate) + self.proj = nn.Linear(hidden_size * 2, vocab_size) + self.emb = nn.Embedding(vocab_size, embed_size) + self.attn_u = Attn(hidden_size) + self.proj_copy1 = nn.Linear(hidden_size, hidden_size) + self.proj_copy2 = nn.Linear(hidden_size, hidden_size) + self.dropout_rate = dropout_rate + init_gru(self.gru) + self.inp_dropout = nn.Dropout(self.dropout_rate) + +
[docs] def forward(self, u_enc_out, z_tm1, last_hidden, u_input_np, pv_z_enc_out, prev_z_input_np, u_emb, pv_z_emb): + + sparse_u_input = Variable(get_sparse_input_aug(u_input_np), requires_grad=False) + + if pv_z_enc_out is not None: + context = self.attn_u(last_hidden, torch.cat([pv_z_enc_out, u_enc_out], dim=0)) + else: + context = self.attn_u(last_hidden, u_enc_out) + embed_z = self.emb(z_tm1) + #embed_z = F.dropout(embed_z, self.dropout_rate) + #embed_z = self.inp_dropout(embed_z) + + gru_in = torch.cat([embed_z, context], 2) + gru_out, last_hidden = self.gru(gru_in, last_hidden) + #gru_out = F.dropout(gru_out, self.dropout_rate) + #gru_out = self.inp_dropout(gru_out) + gen_score = self.proj(torch.cat([gru_out, context], 2)).squeeze(0) + #gen_score = F.dropout(gen_score, self.dropout_rate) + #gen_score = self.inp_dropout(gen_score) + u_copy_score = F.tanh(self.proj_copy1(u_enc_out.transpose(0, 1))) # [B,T,H] + # stable version of copynet + u_copy_score = torch.matmul(u_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2) + u_copy_score = u_copy_score.cpu() + u_copy_score_max = torch.max(u_copy_score, dim=1, keepdim=True)[0] + u_copy_score = torch.exp(u_copy_score - u_copy_score_max) # [B,T] + u_copy_score = torch.log(torch.bmm(u_copy_score.unsqueeze(1), sparse_u_input)).squeeze( + 1) + u_copy_score_max # [B,V] + u_copy_score = cuda_(u_copy_score) + if pv_z_enc_out is None: + #u_copy_score = F.dropout(u_copy_score, self.dropout_rate) + #u_copy_score = self.inp_dropout(u_copy_score) + scores = F.softmax(torch.cat([gen_score, u_copy_score], dim=1), dim=1) + gen_score, u_copy_score = scores[:, :cfg.vocab_size], \ + scores[:, cfg.vocab_size:] + proba = gen_score + u_copy_score[:, :cfg.vocab_size] # [B,V] + proba = torch.cat([proba, u_copy_score[:, cfg.vocab_size:]], 1) + else: + sparse_pv_z_input = Variable(get_sparse_input_aug(prev_z_input_np), requires_grad=False) + pv_z_copy_score = F.tanh(self.proj_copy2(pv_z_enc_out.transpose(0, 1))) # [B,T,H] + pv_z_copy_score = torch.matmul(pv_z_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2) + pv_z_copy_score = pv_z_copy_score.cpu() + pv_z_copy_score_max = torch.max(pv_z_copy_score, dim=1, keepdim=True)[0] + pv_z_copy_score = torch.exp(pv_z_copy_score - pv_z_copy_score_max) # [B,T] + pv_z_copy_score = torch.log(torch.bmm(pv_z_copy_score.unsqueeze(1), sparse_pv_z_input)).squeeze( + 1) + pv_z_copy_score_max # [B,V] + pv_z_copy_score = cuda_(pv_z_copy_score) + scores = F.softmax(torch.cat([gen_score, u_copy_score, pv_z_copy_score], dim=1), dim=1) + gen_score, u_copy_score, pv_z_copy_score = scores[:, :cfg.vocab_size], \ + scores[:, + cfg.vocab_size:2 * cfg.vocab_size + u_input_np.shape[0]], \ + scores[:, 2 * cfg.vocab_size + u_input_np.shape[0]:] + proba = gen_score + u_copy_score[:, :cfg.vocab_size] + pv_z_copy_score[:, :cfg.vocab_size] # [B,V] + proba = torch.cat([proba, pv_z_copy_score[:, cfg.vocab_size:], u_copy_score[:, cfg.vocab_size:]], 1) + return gru_out, last_hidden, proba
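The copy scores above are normalized with a log-sum-exp-style stabilization: the per-example maximum is subtracted before exp(), the sparse position-to-vocabulary aggregation happens in (shifted) probability space, and the maximum is added back after log(). A minimal sketch of just that step (shapes as in the comments above; the final softmax over the concatenated generation and copy scores is omitted):

import torch

def stable_copy_score(copy_logits, sparse_input):
    # copy_logits:  [B, T]      raw scores over source positions
    # sparse_input: [B, T, V+T] position -> (augmented) vocabulary map
    m = copy_logits.max(dim=1, keepdim=True)[0]               # [B, 1]
    p = torch.exp(copy_logits - m)                            # [B, T], <= 1
    agg = torch.bmm(p.unsqueeze(1), sparse_input).squeeze(1)  # [B, V+T]
    return torch.log(agg) + m                                 # log-space scores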
+ + +
[docs]class ResponseDecoder(nn.Module): + def __init__(self, embed_size, hidden_size, vocab_size, degree_size, dropout_rate, gru, proj, emb, vocab): + super().__init__() + self.emb = emb + self.attn_z = Attn(hidden_size) + self.attn_u = Attn(hidden_size) + self.gru = gru + init_gru(self.gru) + self.proj = proj + self.proj_copy1 = nn.Linear(hidden_size, hidden_size) + self.proj_copy2 = nn.Linear(hidden_size, hidden_size) + self.dropout_rate = dropout_rate + + self.vocab = vocab + +
[docs] def get_sparse_selective_input(self, x_input_np): + result = np.zeros((x_input_np.shape[0], x_input_np.shape[1], cfg.vocab_size + x_input_np.shape[0]), dtype=np.float32) + result.fill(1e-10) + reqs = ['address', 'phone', 'postcode', 'pricerange', 'area'] + for t in range(x_input_np.shape[0] - 1): + for b in range(x_input_np.shape[1]): + w = x_input_np[t][b] + word = self.vocab.decode(w) + if word in reqs: + slot = self.vocab.encode(word + '_SLOT') + result[t + 1][b][slot] = 1.0 + else: + if w == 2 or w >= cfg.vocab_size: + result[t+1][b][cfg.vocab_size + t] = 5.0 + else: + result[t+1][b][w] = 1.0 + result_np = result.transpose((1, 0, 2)) + result = torch.from_numpy(result_np).float() + return result
+ +
[docs] def forward(self, z_enc_out, u_enc_out, u_input_np, m_t_input, degree_input, last_hidden, z_input_np): + sparse_z_input = Variable(self.get_sparse_selective_input(z_input_np), requires_grad=False) + + m_embed = self.emb(m_t_input) + z_context = self.attn_z(last_hidden, z_enc_out) + u_context = self.attn_u(last_hidden, u_enc_out) + gru_in = torch.cat([m_embed, u_context, z_context, degree_input.unsqueeze(0)], dim=2) + gru_out, last_hidden = self.gru(gru_in, last_hidden) + gen_score = self.proj(torch.cat([z_context, u_context, gru_out], 2)).squeeze(0) + z_copy_score = F.tanh(self.proj_copy2(z_enc_out.transpose(0, 1))) + z_copy_score = torch.matmul(z_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2) + z_copy_score = z_copy_score.cpu() + z_copy_score_max = torch.max(z_copy_score, dim=1, keepdim=True)[0] + z_copy_score = torch.exp(z_copy_score - z_copy_score_max) # [B,T] + z_copy_score = torch.log(torch.bmm(z_copy_score.unsqueeze(1), sparse_z_input)).squeeze( + 1) + z_copy_score_max # [B,V] + z_copy_score = cuda_(z_copy_score) + + scores = F.softmax(torch.cat([gen_score, z_copy_score], dim=1), dim=1) + gen_score, z_copy_score = scores[:, :cfg.vocab_size], \ + scores[:, cfg.vocab_size:] + proba = gen_score + z_copy_score[:, :cfg.vocab_size] # [B,V] + proba = torch.cat([proba, z_copy_score[:, cfg.vocab_size:]], 1) + return proba, last_hidden, gru_out
+ + +
[docs]class TSD(nn.Module): + def __init__(self, embed_size, hidden_size, vocab_size, degree_size, layer_num, dropout_rate, z_length, + max_ts, beam_search=False, teacher_force=100, **kwargs): + super().__init__() + self.vocab = kwargs['vocab'] + + self.emb = nn.Embedding(vocab_size, embed_size) + self.dec_gru = nn.GRU(degree_size + embed_size + hidden_size * 2, hidden_size, dropout=dropout_rate) + self.proj = nn.Linear(hidden_size * 3, vocab_size) + self.u_encoder = SimpleDynamicEncoder(vocab_size, embed_size, hidden_size, layer_num, dropout_rate) + self.z_decoder = BSpanDecoder(embed_size, hidden_size, vocab_size, dropout_rate) + self.m_decoder = ResponseDecoder(embed_size, hidden_size, vocab_size, degree_size, dropout_rate, + self.dec_gru, self.proj, self.emb, self.vocab) + self.embed_size = embed_size + + self.z_length = z_length + self.max_ts = max_ts + self.beam_search = beam_search + self.teacher_force = teacher_force + + self.pr_loss = nn.NLLLoss(ignore_index=0) + self.dec_loss = nn.NLLLoss(ignore_index=0) + + self.saved_log_policy = [] + + if self.beam_search: + self.beam_size = kwargs['beam_size'] + self.eos_token_idx = kwargs['eos_token_idx'] + +
[docs] def forward(self, u_input, u_input_np, m_input, m_input_np, z_input, u_len, m_len, turn_states, + degree_input, mode, **kwargs): + if mode == 'train' or mode == 'valid': + pz_proba, pm_dec_proba, turn_states = \ + self.forward_turn(u_input, u_len, m_input=m_input, m_len=m_len, z_input=z_input, mode='train', + turn_states=turn_states, degree_input=degree_input, u_input_np=u_input_np, + m_input_np=m_input_np, **kwargs) + loss, pr_loss, m_loss = self.supervised_loss(torch.log(pz_proba), torch.log(pm_dec_proba), + z_input, m_input) + return loss, pr_loss, m_loss, turn_states + + elif mode == 'test': + m_output_index, pz_index, turn_states = self.forward_turn(u_input, u_len=u_len, mode='test', + turn_states=turn_states, + degree_input=degree_input, + u_input_np=u_input_np, m_input_np=m_input_np, + **kwargs + ) + return m_output_index, pz_index, turn_states + elif mode == 'rl': + loss = self.forward_turn(u_input, u_len=u_len, is_train=False, mode='rl', + turn_states=turn_states, + degree_input=degree_input, + u_input_np=u_input_np, m_input_np=m_input_np, + **kwargs + ) + return loss
+ +
[docs] def forward_turn(self, u_input, u_len, turn_states, mode, degree_input, u_input_np, m_input_np=None, + m_input=None, m_len=None, z_input=None, **kwargs): + """ + compute required outputs for a single dialogue turn. Turn state{Dict} will be updated in each call. + :param u_input_np: + :param m_input_np: + :param u_len: + :param turn_states: + :param is_train: + :param u_input: [T,B] + :param m_input: [T,B] + :param z_input: [T,B] + :return: + """ + prev_z_input = kwargs.get('prev_z_input', None) + prev_z_input_np = kwargs.get('prev_z_input_np', None) + prev_z_len = kwargs.get('prev_z_len', None) + pv_z_emb = None + batch_size = u_input.size(1) + pv_z_enc_out = None + + if prev_z_input is not None: + pv_z_enc_out, _, pv_z_emb = self.u_encoder(prev_z_input, prev_z_len) + u_enc_out, u_enc_hidden, u_emb = self.u_encoder(u_input, u_len) + last_hidden = u_enc_hidden[:-1] + z_tm1 = cuda_(Variable(torch.ones(1, batch_size).long() * 3)) # GO_2 token + m_tm1 = cuda_(Variable(torch.ones(1, batch_size).long())) # GO token + if mode == 'train': + pz_dec_outs = [] + pz_proba = [] + z_length = z_input.size(0) if z_input is not None else self.z_length # GO token + hiddens = [None] * batch_size + for t in range(z_length): + pz_dec_out, last_hidden, proba = \ + self.z_decoder(u_enc_out=u_enc_out, u_input_np=u_input_np, + z_tm1=z_tm1, last_hidden=last_hidden, + pv_z_enc_out=pv_z_enc_out, prev_z_input_np=prev_z_input_np, + u_emb=u_emb, pv_z_emb=pv_z_emb) + pz_proba.append(proba) + pz_dec_outs.append(pz_dec_out) + z_np = z_tm1.view(-1).cpu().data.numpy() + for i in range(batch_size): + if z_np[i] == self.vocab.encode('EOS_Z2'): + hiddens[i] = last_hidden[:,i,:] + z_tm1 = z_input[t].view(1, -1) + for i in range(batch_size): + if hiddens[i] is None: + hiddens[i] = last_hidden[:,i,:] + last_hidden = torch.stack(hiddens, dim=1) + + z_input_np = z_input.cpu().data.numpy() + + pz_dec_outs = torch.cat(pz_dec_outs, dim=0) # [Tz,B,H] + pz_proba = torch.stack(pz_proba, dim=0) + # P(m|z,u) + pm_dec_proba, m_dec_outs = [], [] + m_length = m_input.size(0) # Tm + #last_hidden = u_enc_hidden[:-1] + for t in range(m_length): + teacher_forcing = toss_(self.teacher_force) + proba, last_hidden, dec_out = self.m_decoder(pz_dec_outs, u_enc_out, u_input_np, m_tm1, + degree_input, last_hidden, z_input_np) + if teacher_forcing: + m_tm1 = m_input[t].view(1, -1) + else: + _, m_tm1 = torch.topk(proba, 1) + m_tm1 = m_tm1.view(1, -1) + pm_dec_proba.append(proba) + m_dec_outs.append(dec_out) + + pm_dec_proba = torch.stack(pm_dec_proba, dim=0) # [T,B,V] + return pz_proba, pm_dec_proba, None + else: + pz_dec_outs, bspan_index,last_hidden = self.bspan_decoder(u_enc_out, z_tm1, last_hidden, u_input_np, + pv_z_enc_out=pv_z_enc_out, prev_z_input_np=prev_z_input_np, + u_emb=u_emb, pv_z_emb=pv_z_emb) + pz_dec_outs = torch.cat(pz_dec_outs, dim=0) + + if mode == 'test': + if degree_input is None: + degree, degree_input = kwargs.get('func')(bspan_index[0]) + else: + degree = None + + if not self.beam_search: + m_output_index = self.greedy_decode(pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, + degree_input, bspan_index) + + else: + m_output_index = self.beam_search_decode(pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, + degree_input, bspan_index) + + return m_output_index, bspan_index, degree + elif mode == 'rl': + return self.sampling_decode(pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, + degree_input, bspan_index)
+ +
+    def bspan_decoder(self, u_enc_out, z_tm1, last_hidden, u_input_np, pv_z_enc_out, prev_z_input_np, u_emb,
+                      pv_z_emb):
+        pz_dec_outs = []
+        pz_proba = []
+        decoded = []
+        batch_size = u_enc_out.size(1)
+        hiddens = [None] * batch_size
+        for t in range(cfg.z_length):
+            pz_dec_out, last_hidden, proba = \
+                self.z_decoder(u_enc_out=u_enc_out, u_input_np=u_input_np,
+                               z_tm1=z_tm1, last_hidden=last_hidden, pv_z_enc_out=pv_z_enc_out,
+                               prev_z_input_np=prev_z_input_np, u_emb=u_emb, pv_z_emb=pv_z_emb)
+            pz_proba.append(proba)
+            pz_dec_outs.append(pz_dec_out)
+            z_proba, z_index = torch.topk(proba, 1)  # [B,1]
+            z_index = z_index.data.view(-1)
+            decoded.append(z_index.clone())
+            for i in range(z_index.size(0)):
+                if z_index[i] >= cfg.vocab_size:
+                    z_index[i] = 2  # unk
+            z_np = z_tm1.view(-1).cpu().data.numpy()
+            for i in range(batch_size):
+                if z_np[i] == self.vocab.encode('EOS_Z2'):
+                    hiddens[i] = last_hidden[:, i, :]
+            z_tm1 = cuda_(Variable(z_index).view(1, -1))
+        for i in range(batch_size):
+            if hiddens[i] is None:
+                hiddens[i] = last_hidden[:, i, :]
+        last_hidden = torch.stack(hiddens, dim=1)
+        decoded = torch.stack(decoded, dim=0).transpose(0, 1)
+        decoded = list(decoded)
+        decoded = [list(_) for _ in decoded]
+        return pz_dec_outs, decoded, last_hidden
+ +
+    def greedy_decode(self, pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, degree_input, bspan_index):
+        decoded = []
+        bspan_index_np = pad_sequences(bspan_index).transpose((1, 0))
+        for t in range(self.max_ts):
+            proba, last_hidden, _ = self.m_decoder(pz_dec_outs, u_enc_out, u_input_np, m_tm1,
+                                                   degree_input, last_hidden, bspan_index_np)
+            proba = torch.cat((proba[:, :2], proba[:, 3:]), 1)  # drop the unk token (index 2) from candidates
+            mt_proba, mt_index = torch.topk(proba, 1)  # [B,1]
+            mt_index.add_(mt_index.ge(2).long())  # shift indices >= 2 back to the original vocabulary
+            mt_index = mt_index.data.view(-1)
+            decoded.append(mt_index.clone())
+            for i in range(mt_index.size(0)):
+                if mt_index[i] >= cfg.vocab_size:
+                    mt_index[i] = 2  # unk
+            m_tm1 = cuda_(Variable(mt_index).view(1, -1))
+        decoded = torch.stack(decoded, dim=0).transpose(0, 1)
+        decoded = list(decoded)
+        return [list(_) for _ in decoded]
+ +
+    def beam_search_decode_single(self, pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, degree_input,
+                                  bspan_index):
+        eos_token_id = self.vocab.encode(cfg.eos_m_token)
+        batch_size = pz_dec_outs.size(1)
+        if batch_size != 1:
+            raise ValueError('"Beam search single" requires batch size to be 1')
+
+        class BeamState:
+            def __init__(self, score, last_hidden, decoded, length):
+                """
+                Beam state in beam decoding
+                :param score: sum of log-probabilities
+                :param last_hidden: last hidden state
+                :param decoded: list of *Variable[1*1]* of all decoded words
+                :param length: current decoded sentence length
+                """
+                self.score = score
+                self.last_hidden = last_hidden
+                self.decoded = decoded
+                self.length = length
+
+            def update_clone(self, score_incre, last_hidden, decoded_t):
+                decoded = copy.copy(self.decoded)
+                decoded.append(decoded_t)
+                clone = BeamState(self.score + score_incre, last_hidden, decoded, self.length + 1)
+                return clone
+
+        def beam_result_valid(decoded_t, bspan_index):
+            decoded_t = [_.view(-1).data[0] for _ in decoded_t]
+            req_slots = self.get_req_slots(bspan_index)
+            decoded_sentence = self.vocab.sentence_decode(decoded_t, cfg.eos_m_token)
+            for req in req_slots:
+                if req not in decoded_sentence:
+                    return False
+            return True
+
+        def score_bonus(state, decoded, bspan_index):
+            bonus = cfg.beam_len_bonus
+            return bonus
+
+        def soft_score_incre(score, turn):
+            return score
+
+        finished, failed = [], []
+        states = []  # sorted by score decreasingly
+        dead_k = 0
+        states.append(BeamState(0, last_hidden, [m_tm1], 0))
+        bspan_index_np = np.array(bspan_index).reshape(-1, 1)
+        for t in range(self.max_ts):
+            new_states = []
+            k = 0
+            while k < len(states) and k < self.beam_size - dead_k:
+                state = states[k]
+                last_hidden, m_tm1 = state.last_hidden, state.decoded[-1]
+                proba, last_hidden, _ = self.m_decoder(pz_dec_outs, u_enc_out, u_input_np, m_tm1, degree_input,
+                                                       last_hidden, bspan_index_np)
+                proba = torch.log(proba)
+                mt_proba, mt_index = torch.topk(proba, self.beam_size - dead_k)  # [1,K]
+                for new_k in range(self.beam_size - dead_k):
+                    score_incre = soft_score_incre(mt_proba[0][new_k].data[0], t) + \
+                                  score_bonus(state, mt_index[0][new_k].data[0], bspan_index)
+                    if len(new_states) >= self.beam_size - dead_k and \
+                            state.score + score_incre < new_states[-1].score:
+                        break
+                    decoded_t = mt_index[0][new_k]
+                    if decoded_t.data[0] >= cfg.vocab_size:
+                        decoded_t.data[0] = 2  # unk
+                    if self.vocab.decode(decoded_t.data[0]) == cfg.eos_m_token:
+                        if beam_result_valid(state.decoded, bspan_index):
+                            finished.append(state)
+                            dead_k += 1
+                        else:
+                            failed.append(state)
+                    else:
+                        decoded_t = decoded_t.view(1, -1)
+                        new_state = state.update_clone(score_incre, last_hidden, decoded_t)
+                        new_states.append(new_state)
+                k += 1
+            if self.beam_size - dead_k < 0:
+                break
+            new_states = new_states[:self.beam_size - dead_k]
+            new_states.sort(key=lambda x: -x.score)
+            states = new_states
+
+            if t == self.max_ts - 1 and not finished:
+                finished = failed
+                print('FAIL')
+                if not finished:
+                    finished.append(states[0])
+
+        finished.sort(key=lambda x: -x.score)
+        decoded_t = finished[0].decoded
+        decoded_t = [_.view(-1).data[0] for _ in decoded_t]
+        decoded_sentence = self.vocab.sentence_decode(decoded_t, cfg.eos_m_token)
+        print(decoded_sentence)
+        generated = torch.cat(finished[0].decoded, dim=1).data  # [B=1, T]
+        return generated
+ +
+    def beam_search_decode(self, pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, degree_input, bspan_index):
+        vars = torch.split(pz_dec_outs, 1, dim=1), torch.split(u_enc_out, 1, dim=1), torch.split(
+            m_tm1, 1, dim=1), torch.split(last_hidden, 1, dim=1), torch.split(degree_input, 1, dim=0)
+        decoded = []
+        for i, (pz_dec_out_s, u_enc_out_s, m_tm1_s, last_hidden_s, degree_input_s) in enumerate(zip(*vars)):
+            decoded_s = self.beam_search_decode_single(pz_dec_out_s, u_enc_out_s, m_tm1_s,
+                                                       u_input_np[:, i].reshape((-1, 1)),
+                                                       last_hidden_s, degree_input_s, bspan_index[i])
+            decoded.append(decoded_s)
+        return [list(_.view(-1)) for _ in decoded]
+ +
+    def supervised_loss(self, pz_proba, pm_dec_proba, z_input, m_input):
+        pz_proba = pz_proba[:, :, :cfg.vocab_size].contiguous()
+        pm_dec_proba = pm_dec_proba[:, :, :cfg.vocab_size].contiguous()
+        pr_loss = self.pr_loss(pz_proba.view(-1, pz_proba.size(2)), z_input.view(-1))
+        m_loss = self.dec_loss(pm_dec_proba.view(-1, pm_dec_proba.size(2)), m_input.view(-1))
+
+        loss = pr_loss + m_loss
+        return loss, pr_loss, m_loss
+ +
+    def self_adjust(self, epoch):
+        pass
+
+    # REINFORCEMENT fine-tuning with MC
+
+    def get_req_slots(self, bspan_index):
+        reqs = ['address', 'phone', 'postcode', 'pricerange', 'area']
+        reqs = set(self.vocab.sentence_decode(bspan_index).split()).intersection(reqs)
+        return [_ + '_SLOT' for _ in reqs]
+ +
+    def reward(self, m_tm1, decoded, bspan_index):
+        """
+        Heuristic reward for RL fine-tuning: reward the first mention of each requested slot,
+        penalize repeating one. The shaping is heuristic and could be tuned further.
+        :param m_tm1: previously generated response token
+        :param decoded: tokens decoded so far
+        :param bspan_index: belief span tokens
+        :return: (reward, finished)
+        """
+        req_slots = self.get_req_slots(bspan_index)
+        all_reqs = ['address', 'phone', 'postcode', 'pricerange', 'area']
+        all_reqs = [_ + '_SLOT' for _ in all_reqs]
+
+        m_tm1 = self.vocab.decode(m_tm1[0])
+        finished = m_tm1 == 'EOS_M'
+        decoded = [_.view(-1)[0] for _ in decoded]
+        decoded_sentence = self.vocab.sentence_decode(decoded, cfg.eos_m_token).split()
+        reward = 0.0
+        if m_tm1 in req_slots:
+            if decoded_sentence and m_tm1 not in decoded_sentence[:-1]:
+                reward += 1.5
+            else:
+                reward -= 1.0  # repeat
+        elif m_tm1 in all_reqs:
+            if decoded_sentence and m_tm1 not in decoded_sentence[:-1]:
+                reward += 0.5
+        return reward, finished
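For intuition, a standalone sketch of how this shaping behaves, mirroring the constants above (the slot tokens and sequence are invented for illustration):

req_slots = ['phone_SLOT', 'area_SLOT']  # hypothetical requested slots for one turn
all_reqs = [s + '_SLOT' for s in ['address', 'phone', 'postcode', 'pricerange', 'area']]

def toy_reward(token, prefix):
    # +1.5 for the first mention of a requested slot, -1.0 for repeating it,
    # +0.5 for the first mention of any other requestable slot
    if token in req_slots:
        return 1.5 if token not in prefix else -1.0
    if token in all_reqs and token not in prefix:
        return 0.5
    return 0.0

prefix = []
for tok in ['phone_SLOT', 'phone_SLOT', 'address_SLOT']:
    print(tok, toy_reward(tok, prefix))  # 1.5, then -1.0, then 0.5
    prefix.append(tok)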
+ +
+    def sampling_decode(self, pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, degree_input, bspan_index):
+        vars = torch.split(pz_dec_outs, 1, dim=1), torch.split(u_enc_out, 1, dim=1), torch.split(
+            m_tm1, 1, dim=1), torch.split(last_hidden, 1, dim=1), torch.split(degree_input, 1, dim=0)
+        batch_loss = []
+        sample_num = 1
+        for i, (pz_dec_out_s, u_enc_out_s, m_tm1_s, last_hidden_s, degree_input_s) in enumerate(zip(*vars)):
+            if not self.get_req_slots(bspan_index[i]):
+                continue
+            for j in range(sample_num):
+                loss = self.sampling_decode_single(pz_dec_out_s, u_enc_out_s, m_tm1_s,
+                                                   u_input_np[:, i].reshape((-1, 1)),
+                                                   last_hidden_s, degree_input_s, bspan_index[i])
+                batch_loss.append(loss)
+        if not batch_loss:
+            return None
+        return sum(batch_loss) / len(batch_loss)
+ +
+    def sampling_decode_single(self, pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, degree_input,
+                               bspan_index):
+        decoded = []
+        reward_sum = 0
+        log_probs = []
+        rewards = []
+        bspan_index_np = np.array(bspan_index).reshape(-1, 1)
+        for t in range(self.max_ts):
+            # reward
+            reward, finished = self.reward(m_tm1.data.view(-1), decoded, bspan_index)
+            reward_sum += reward
+            rewards.append(reward)
+            if t == self.max_ts - 1:
+                finished = True
+            if finished:
+                loss = self.finish_episode(log_probs, rewards)
+                return loss
+            # action
+            proba, last_hidden, _ = self.m_decoder(pz_dec_outs, u_enc_out, u_input_np, m_tm1,
+                                                   degree_input, last_hidden, bspan_index_np)
+            proba = proba.squeeze(0)  # [B,V]
+            dis = Categorical(proba)
+            action = dis.sample()
+            log_probs.append(dis.log_prob(action))
+            mt_index = action.data.view(-1)
+            decoded.append(mt_index.clone())
+
+            for i in range(mt_index.size(0)):
+                if mt_index[i] >= cfg.vocab_size:
+                    mt_index[i] = 2  # unk
+
+            m_tm1 = cuda_(Variable(mt_index).view(1, -1))
+ +
+    def finish_episode(self, log_probas, saved_rewards):
+        R = 0
+        policy_loss = []
+        rewards = []
+        for r in reversed(saved_rewards):  # discounted return, accumulated from the end of the episode
+            R = r + 0.8 * R
+            rewards.insert(0, R)
+
+        rewards = torch.Tensor(rewards)
+        # update: we noticed improved performance without reward normalization
+        # rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)
+
+        for log_prob, reward in zip(log_probas, rewards):
+            policy_loss.append((-log_prob * reward).unsqueeze(0))
+        l = len(policy_loss)
+        policy_loss = torch.cat(policy_loss).sum()
+        return policy_loss / l
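As a sanity check on the discounted-return computation above (gamma = 0.8 as in the code; the per-step rewards are invented), a minimal standalone version:

gamma = 0.8
saved_rewards = [0.0, 1.5, -1.0]  # hypothetical per-step rewards

R, returns = 0.0, []
for r in reversed(saved_rewards):  # accumulate from the last step backwards
    R = r + gamma * R
    returns.insert(0, R)
print(returns)  # [0.56, 0.7, -1.0]; each entry then weights -log_prob in the policy-gradient loss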
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlg/multiwoz/evaluate.html b/docs/build/html/_modules/convlab/modules/nlg/multiwoz/evaluate.html
new file mode 100644
index 0000000..011cc0e
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlg/multiwoz/evaluate.html
@@ -0,0 +1,334 @@

Source code for convlab.modules.nlg.multiwoz.evaluate

+"""
+Evaluate NLG models on system utterances of the MultiWOZ test dataset
+Metric: dataset level BLEU-4, slot error rate
+Usage: PYTHONPATH=../../../.. python evaluate.py [SCLSTM|MultiwozTemplateNLG]
+"""
+import json
+import random
+import sys
+import zipfile
+
+import numpy as np
+import torch
+from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
+
+from convlab.modules.nlg.multiwoz.multiwoz_template_nlg import MultiwozTemplateNLG
+from convlab.modules.nlg.multiwoz.sc_lstm.nlg_sc_lstm import SCLSTM
+
+seed = 2019
+random.seed(seed)
+np.random.seed(seed)
+torch.manual_seed(seed)
+
+
+
+def get_bleu4(dialog_acts, golden_utts, gen_utts):
+    das2utts = {}
+    for das, utt, gen in zip(dialog_acts, golden_utts, gen_utts):
+        utt = utt.lower()
+        gen = gen.lower()
+        for da, svs in das.items():
+            domain, act = da.split('-')
+            if act == 'Request' or domain == 'general':
+                continue
+            for s, v in sorted(svs, key=lambda x: x[0]):
+                if s == 'Internet' or s == 'Parking' or s == 'none' or v == 'none':
+                    continue
+                v = v.lower()
+                if (' ' + v in utt) or (v + ' ' in utt):
+                    utt = utt.replace(v, '{}-{}'.format(da, s), 1)
+                if (' ' + v in gen) or (v + ' ' in gen):
+                    gen = gen.replace(v, '{}-{}'.format(da, s), 1)
+        hash_key = ''
+        for da in sorted(das.keys()):
+            for s, v in sorted(das[da], key=lambda x: x[0]):
+                hash_key += da + '-' + s + ';'
+        das2utts.setdefault(hash_key, {'refs': [], 'gens': []})
+        das2utts[hash_key]['refs'].append(utt)
+        das2utts[hash_key]['gens'].append(gen)
+    refs, gens = [], []
+    for das in das2utts.keys():
+        for gen in das2utts[das]['gens']:
+            refs.append([s.split() for s in das2utts[das]['refs']])
+            gens.append(gen.split())
+    bleu = corpus_bleu(refs, gens, weights=(0.25, 0.25, 0.25, 0.25),
+                       smoothing_function=SmoothingFunction().method1)
+    return bleu
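The delexicalization above replaces each realized value with a `DA-slot` placeholder so utterances sharing the same acts can be pooled as references. A toy call (utterances invented):

acts = [{"Restaurant-Inform": [["Food", "japanese"]]}]
golden = ["i recommend a japanese place"]
gen = ["how about a japanese restaurant"]
print(get_bleu4(acts, golden, gen))  # smoothed corpus-level BLEU-4 in [0, 1]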
+ + +
+def get_err_slot(dialog_acts, nlg_model):
+    assert isinstance(nlg_model, SCLSTM)
+    errs = []
+    N_total, p_total, q_total = 0, 0, 0
+    for i, das in enumerate(dialog_acts):
+        print('[%d/%d]' % (i + 1, len(dialog_acts)))
+        gen = nlg_model.generate_slots(das)
+        triples = []
+        counter = {}
+        for da in das:
+            if 'Request' in da or 'general' in da:
+                continue
+            for s, v in das[da]:
+                if s == 'Internet' or s == 'Parking' or s == 'none' or v == 'none':
+                    continue
+                slot = da.lower() + '-' + s.lower()
+                counter.setdefault(slot, 0)
+                counter[slot] += 1
+                triples.append(slot + '-' + str(counter[slot]))
+        assert len(set(triples)) == len(triples)
+        assert len(set(gen)) == len(gen)
+        N = len(triples)
+        p = len(set(triples) - set(gen))  # missing slots
+        q = len(set(gen) - set(triples))  # redundant slots
+        N_total += N
+        p_total += p
+        q_total += q
+        if N > 0:
+            err = (p + q) * 1.0 / N
+            print(err)
+            errs.append(err)
+    print('mean(std): {}({})'.format(np.mean(errs), np.std(errs)))
+    if N_total > 0:
+        print('divide after sum:', (p_total + q_total) / N_total)
+    return np.mean(errs)
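The per-act slot error rate is (p + q) / N: missing plus redundant slots over the number of gold slots. A worked instance (slot ids invented):

triples = ['restaurant-inform-food-1', 'restaurant-inform-area-1', 'restaurant-inform-name-1']
gen = ['restaurant-inform-food-1', 'restaurant-inform-area-1', 'restaurant-inform-phone-1']
p = len(set(triples) - set(gen))  # 1 missing slot
q = len(set(gen) - set(triples))  # 1 redundant slot
print((p + q) / len(triples))     # 0.666...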
+
+
+if __name__ == '__main__':
+    if len(sys.argv) != 2:
+        print("usage:")
+        print("\t python evaluate.py model_name")
+        print("\t model_name=SCLSTM or MultiwozTemplateNLG")
+        sys.exit()
+    model_name = sys.argv[1]
+    print("Loading", model_name)
+    if model_name == 'SCLSTM':
+        model_sys = SCLSTM(model_file="https://convlab.blob.core.windows.net/models/nlg-sclstm-multiwoz.zip")
+    elif model_name == 'MultiwozTemplateNLG':
+        model_sys = MultiwozTemplateNLG(is_user=False)
+    else:
+        raise Exception("Available model: SCLSTM, MultiwozTemplateNLG")
+
+    archive = zipfile.ZipFile('../../../../data/multiwoz/test.json.zip', 'r')
+    test_data = json.load(archive.open('test.json'))
+
+    dialog_acts = []
+    golden_utts = []
+    gen_utts = []
+
+    sess_num = 0
+    for no, sess in list(test_data.items()):
+        sess_num += 1
+        print('[%d/%d]' % (sess_num, len(test_data)))
+        for i, turn in enumerate(sess['log']):
+            if i % 2 == 0:  # skip user turns; evaluate system utterances only
+                continue
+            dialog_acts.append(turn['dialog_act'])
+            golden_utts.append(turn['text'])
+            gen_utts.append(model_sys.generate(turn['dialog_act']))
+
+    print("Calculate bleu-4")
+    bleu4 = get_bleu4(dialog_acts, golden_utts, gen_utts)
+    print("BLEU-4: %.4f" % bleu4)
+
+    if model_name == 'SCLSTM':
+        print("Calculate slot error rate:")
+        err = get_err_slot(dialog_acts, model_sys)
+        print('ERR:', err)
+
+    print("BLEU-4: %.4f" % bleu4)
+    print(model_name)
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlg/multiwoz/multiwoz_template_nlg/multiwoz_template_nlg.html b/docs/build/html/_modules/convlab/modules/nlg/multiwoz/multiwoz_template_nlg/multiwoz_template_nlg.html
new file mode 100644
index 0000000..72a1982
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlg/multiwoz/multiwoz_template_nlg/multiwoz_template_nlg.html
@@ -0,0 +1,385 @@

Source code for convlab.modules.nlg.multiwoz.multiwoz_template_nlg.multiwoz_template_nlg

+"""
+Template-based NLG for the MultiWOZ dataset. Templates live in the `multiwoz_template_nlg/` dir.
+See `example` function in this file for usage.
+"""
+import json
+import os
+import random
+from pprint import pprint
+
+from convlab.modules.nlg.nlg import NLG
+
+
+
+def read_json(filename):
+    with open(filename, 'r') as f:
+        return json.load(f)
+
+
+# supported slots
+slot2word = {
+    'Fee': 'fee',
+    'Addr': 'address',
+    'Area': 'area',
+    'Stars': 'stars',
+    'Internet': 'Internet',
+    'Department': 'department',
+    'Choice': 'choice',
+    'Ref': 'reference number',
+    'Food': 'food',
+    'Type': 'type',
+    'Price': 'price range',
+    'Stay': 'stay',
+    'Phone': 'phone',
+    'Post': 'postcode',
+    'Day': 'day',
+    'Name': 'name',
+    'Car': 'car type',
+    'Leave': 'leave',
+    'Time': 'time',
+    'Arrive': 'arrive',
+    'Ticket': 'ticket',
+    'Depart': 'departure',
+    'People': 'people',
+    'Dest': 'destination',
+    'Parking': 'parking',
+    'Open': 'open',
+    'Id': 'Id',
+    # 'TrainID': 'TrainID'
+}
+class MultiwozTemplateNLG(NLG):
+    def __init__(self, is_user, mode="manual"):
+        """
+        :param is_user: whether the dialog acts come from the user or the system
+        :param mode: `auto`: templates extracted from data without manual modification, may have no match;
+                     `manual`: templates with manual modification, sometimes verbose;
+                     `auto_manual`: try auto templates first; on failure, fall back to manual templates.
+                     Both templates are dicts: *_template[dialog_act][slot] is a list of templates.
+        """
+        super().__init__()
+        self.is_user = is_user
+        self.mode = mode
+        template_dir = os.path.dirname(os.path.abspath(__file__))
+        self.auto_user_template = read_json(os.path.join(template_dir, 'auto_user_template_nlg.json'))
+        self.auto_system_template = read_json(os.path.join(template_dir, 'auto_system_template_nlg.json'))
+        self.manual_user_template = read_json(os.path.join(template_dir, 'manual_user_template_nlg.json'))
+        self.manual_system_template = read_json(os.path.join(template_dir, 'manual_system_template_nlg.json'))
+    def generate(self, dialog_acts):
+        """
+        NLG for the MultiWOZ dataset
+        :param dialog_acts: {da1: [[slot1, value1], ...], da2: ...}
+        :return: generated sentence
+        """
+        mode = self.mode
+        try:
+            is_user = self.is_user
+            if mode == 'manual':
+                template = self.manual_user_template if is_user else self.manual_system_template
+                return self._manual_generate(dialog_acts, template)
+            elif mode == 'auto':
+                template = self.auto_user_template if is_user else self.auto_system_template
+                return self._auto_generate(dialog_acts, template)
+            elif mode == 'auto_manual':
+                if is_user:
+                    template1 = self.auto_user_template
+                    template2 = self.manual_user_template
+                else:
+                    template1 = self.auto_system_template
+                    template2 = self.manual_system_template
+                res = self._auto_generate(dialog_acts, template1)
+                if res == 'None':
+                    res = self._manual_generate(dialog_acts, template2)
+                return res
+            else:
+                raise Exception("Invalid mode! available mode: auto, manual, auto_manual")
+        except Exception as e:
+            print('Error in processing:')
+            pprint(dialog_acts)
+            raise e
+
+    def _postprocess(self, sen):
+        sen = sen.strip().capitalize()
+        if len(sen) > 0 and sen[-1] != '?' and sen[-1] != '.':
+            sen += '.'
+        sen += ' '
+        return sen
+
+    def _manual_generate(self, dialog_acts, template):
+        sentences = ''
+        for dialog_act, slot_value_pairs in dialog_acts.items():
+            intent = dialog_act.split('-')
+            if 'Select' == intent[1]:
+                slot2values = {}
+                for slot, value in slot_value_pairs:
+                    slot2values.setdefault(slot, [])
+                    slot2values[slot].append(value)
+                for slot, values in slot2values.items():
+                    if slot == 'none':
+                        continue
+                    sentence = 'Do you prefer ' + values[0]
+                    for i, value in enumerate(values[1:]):
+                        if i == (len(values) - 2):
+                            sentence += ' or ' + value
+                        else:
+                            sentence += ' , ' + value
+                    sentence += ' {} ? '.format(slot2word[slot])
+                    sentences += sentence
+            elif 'Request' == intent[1]:
+                for slot, value in slot_value_pairs:
+                    if dialog_act not in template or slot not in template[dialog_act]:
+                        sentence = 'What is the {} of {} ? '.format(slot, dialog_act.split('-')[0].lower())
+                        sentences += sentence
+                    else:
+                        sentence = random.choice(template[dialog_act][slot])
+                        sentence = self._postprocess(sentence)
+                        sentences += sentence
+            elif 'general' == intent[0] and dialog_act in template:
+                sentence = random.choice(template[dialog_act]['none'])
+                sentence = self._postprocess(sentence)
+                sentences += sentence
+            else:
+                for slot, value in slot_value_pairs:
+                    if dialog_act in template and slot in template[dialog_act]:
+                        sentence = random.choice(template[dialog_act][slot])
+                        sentence = sentence.replace('#{}-{}#'.format(dialog_act.upper(), slot.upper()), str(value))
+                    else:
+                        if slot in slot2word:
+                            sentence = 'The {} is {} . '.format(slot2word[slot], str(value))
+                        else:
+                            sentence = ''
+                    sentence = self._postprocess(sentence)
+                    sentences += sentence
+        return sentences.strip()
+
+    def _auto_generate(self, dialog_acts, template):
+        sentences = ''
+        for dialog_act, slot_value_pairs in dialog_acts.items():
+            key = ''
+            for s, v in sorted(slot_value_pairs, key=lambda x: x[0]):
+                key += s + ';'
+            if dialog_act in template and key in template[dialog_act]:
+                sentence = random.choice(template[dialog_act][key])
+                if 'Request' in dialog_act or 'general' in dialog_act:
+                    sentence = self._postprocess(sentence)
+                    sentences += sentence
+                else:
+                    for s, v in sorted(slot_value_pairs, key=lambda x: x[0]):
+                        if v != 'none':
+                            sentence = sentence.replace('#{}-{}#'.format(dialog_act.upper(), s.upper()), v, 1)
+                    sentence = self._postprocess(sentence)
+                    sentences += sentence
+            else:
+                return 'None'
+        return sentences.strip()
+ + +
+def example():
+    # dialog act
+    dialog_acts = {}
+    # whether from user or system
+    is_user = False
+
+    multiwoz_template_nlg = MultiwozTemplateNLG(is_user)
+    print(multiwoz_template_nlg.generate(dialog_acts))
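The shipped example passes an empty dialog-act dict, which generates nothing. A minimal populated sketch, reusing the act format from template_nlg.py later in this diff (the exact wording depends on the sampled template):

nlg = MultiwozTemplateNLG(is_user=False, mode='manual')
dialog_acts = {"Restaurant-Inform": [["Food", "japanese"], ["Time", "17:45"]]}
print(nlg.generate(dialog_acts))  # e.g. "The food is japanese . The time is 17:45 ."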
+
+
+if __name__ == '__main__':
+    example()
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlg/multiwoz/sc_lstm/bleu.html b/docs/build/html/_modules/convlab/modules/nlg/multiwoz/sc_lstm/bleu.html
new file mode 100644
index 0000000..65904dc
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlg/multiwoz/sc_lstm/bleu.html
@@ -0,0 +1,364 @@

Source code for convlab.modules.nlg.multiwoz.sc_lstm.bleu

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import argparse
+import json
+import sys
+import time
+
+from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
+
+
+#def delexicalise(sent,dact): # for domain4
+#	feat = SoftDActFormatter().parse(dact,keepValues=True)
+#	return ExactMatchDataLexicaliser().delexicalise(sent,feat['s2v'])
+#
+#
+#def lexicalise(sent,dact): # for domain4
+#	feat = SoftDActFormatter().parse(dact,keepValues=True)
+#	return ExactMatchDataLexicaliser().lexicalise(sent,feat['s2v'])
+#
+#
+#def parse_sr(sr, domain): # for domain4
+#	'''
+#	input da: 'inform(name=piperade;goodformeal=dinner;food=basque)'
+#	return  : a str 'domain|da|slot1, slot2, ...'
+#	Note: cannot deal with repeat slots, e.g. slot_name*2 will has the same sr as slot_name*1
+#	'''
+#	da = sr.split('(')[0]
+#	_sr = sr.split('(')[1].split(')')[0].split(';')
+#	slots = []
+#	for sv in _sr:
+#		slots.append(sv.split('=')[0])
+#	slots = sorted(slots)
+#
+#	res = domain + '|' + da + '|'
+#	for slot in slots:
+#		res  += (slot+',')
+#	res = (res[:-1]) # remove last ,
+#	return res
+#
+#
+
+# def score_domain4(res_file):
+# 	# parse test set to have semantic representation of each target
+# 	target2sr = {} # target sentence to a defined str of sr
+# 	sr2content = {}
+# 	domains = ['restaurant', 'hotel', 'tv', 'laptop']
+# 	repeat_count = 0
+# 	for domain in domains:
+# 		with open('data/domain4/original/'+domain+'/test.json') as f:
+# 			for i in range(5):
+# 				f.readline()
+# 			data = json.load(f)
+#
+# 		for sr, target, base in data:
+# 			target = delexicalise( normalize(re.sub(' [\.\?\!]$','',target)),sr)
+# 			target = lexicalise(target, sr)
+#
+# 			sr = parse_sr(sr, domain)
+# 			if target in target2sr:
+# 				repeat_count += 1
+# 				continue
+# 			if target[-1] == ' ':
+# 				target = target[:-1]
+# 			target2sr[target] = sr
+#
+# 			if sr not in sr2content:
+# 				sr2content[sr] = [[], [], []] # [ [refs], [bases], [gens] ]
+#
+# 	with open(res_file) as f:
+# 		for line in f:
+# 			if 'Target' in line:
+# 				target = line.strip().split(':')[1][1:]
+# 				sr = target2sr[target]
+# 				sr2content[sr][0].append(target)
+#
+# 			if 'Base' in line:
+# 				base = line.strip().split(':')[1][1:]
+# 				if base[-1] == ' ':
+# 					base = base[:-1]
+# 				sr2content[sr][1].append(base)
+#
+# 			if 'Gen' in line:
+# 				gen = line.strip().split(':')[1][1:]
+# 				sr2content[sr][2].append(gen)
+#
+# 	return sr2content
+
+
+
+def score_woz(res_file, ignore=False):
+    feat2content = {}
+    with open(res_file) as f:
+        for line in f:
+            if 'Feat' in line:
+                feat = line.strip().split(':')[1][1:]
+                if feat not in feat2content:
+                    feat2content[feat] = [[], [], []]  # [ [refs], [bases], [gens] ]
+                continue
+
+            if 'Target' in line:
+                target = line.strip().split(':')[1][1:]
+                if feat in feat2content:
+                    feat2content[feat][0].append(target)
+
+            if 'Base' in line:
+                base = line.strip().split(':')[1][1:]
+                if base[-1] == ' ':
+                    base = base[:-1]
+                if feat in feat2content:
+                    feat2content[feat][1].append(base)
+
+            if 'Gen' in line:
+                gen = line.strip().split(':')[1][1:]
+                if feat in feat2content:
+                    feat2content[feat][2].append(gen)
+
+    return feat2content
+ +
+def get_bleu(feat2content, template=False, ignore=False):
+    test_type = 'base' if template else 'gen'
+    print('Start', test_type, file=sys.stderr)
+
+    gen_count = 0
+    list_of_references, hypotheses = {'gen': [], 'base': []}, {'gen': [], 'base': []}
+    for feat in feat2content:
+        refs, bases, gens = feat2content[feat]
+        gen_count += len(gens)
+        refs = [s.split() for s in refs]
+
+        for gen in gens:
+            list_of_references['gen'].append(refs)
+            hypotheses['gen'].append(gen.split())
+
+        for base in bases:
+            list_of_references['base'].append(refs)
+            hypotheses['base'].append(base.split())
+
+    print('TEST TYPE:', test_type)
+    print('Ignore General Acts:', ignore)
+    smooth = SmoothingFunction()
+    print('Calculating BLEU...', file=sys.stderr)
+    print('Avg # feat:', len(feat2content))
+    print('Avg # gen: {:.2f}'.format(gen_count / len(feat2content)))
+    BLEU = []
+    weights = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.333, 0.333, 0.333, 0), (0.25, 0.25, 0.25, 0.25)]
+    for i in range(4):
+        if i in (0, 1, 2):  # only BLEU-4 is computed; lower orders are skipped
+            continue
+        t = time.time()
+        bleu = corpus_bleu(list_of_references[test_type], hypotheses[test_type], weights=weights[i],
+                           smoothing_function=smooth.method1)
+        BLEU.append(bleu)
+        print('Done BLEU-{}, time:{:.1f}'.format(i + 1, time.time() - t))
+    print('BLEU 1-4:', BLEU)
+    print('BLEU 1-4:', BLEU, file=sys.stderr)
+    print('Done', test_type, file=sys.stderr)
+    print('-----------------------------------')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Train dialogue generator')
+    parser.add_argument('--res_file', type=str, help='result file')
+    parser.add_argument('--dataset', type=str, default='woz', help='dataset: woz or domain4')
+    parser.add_argument('--template', type=bool, default=False, help='test on template-based words')
+    parser.add_argument('--ignore', type=bool, default=False, help='whether to ignore general acts, e.g. bye')
+    args = parser.parse_args()
+    assert args.dataset == 'woz' or args.dataset == 'domain4'
+    if args.dataset == 'woz':
+        assert args.template is False
+        feat2content = score_woz(args.res_file, ignore=args.ignore)
+    else:  # domain4 scoring is not ported; see the commented-out score_domain4 above
+        raise NotImplementedError
+    get_bleu(feat2content, template=args.template, ignore=args.ignore)
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlg/multiwoz/sc_lstm/nlg_sc_lstm.html b/docs/build/html/_modules/convlab/modules/nlg/multiwoz/sc_lstm/nlg_sc_lstm.html
new file mode 100644
index 0000000..be029ba
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlg/multiwoz/sc_lstm/nlg_sc_lstm.html
@@ -0,0 +1,383 @@

Source code for convlab.modules.nlg.multiwoz.sc_lstm.nlg_sc_lstm

+# -*- coding: utf-8 -*-
+
+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+"""
+import configparser
+import os
+import zipfile
+from copy import deepcopy
+
+import torch
+
+from convlab.lib.file_util import cached_path
+from convlab.modules.nlg.multiwoz.sc_lstm.loader.dataset_woz import SimpleDatasetWoz
+from convlab.modules.nlg.multiwoz.sc_lstm.model.lm_deep import LMDeep
+from convlab.modules.nlg.nlg import NLG
+
+DEFAULT_DIRECTORY = "models"
+DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "nlg-sclstm-multiwoz.zip")
+
+
+def parse(is_user):
+    if is_user:
+        args = {
+            'model_path': 'sclstm_usr.pt',
+            'n_layer': 1,
+            'beam_size': 10
+        }
+    else:
+        args = {
+            'model_path': 'sclstm.pt',
+            'n_layer': 1,
+            'beam_size': 10
+        }
+
+    config = configparser.ConfigParser()
+    if is_user:
+        config.read(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config/config_usr.cfg'))
+    else:
+        config.read(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config/config.cfg'))
+    config.set('DATA', 'dir', os.path.dirname(os.path.abspath(__file__)))
+
+    return args, config
+ + +
+class SCLSTM(NLG):
+    def __init__(self,
+                 archive_file=DEFAULT_ARCHIVE_FILE,
+                 use_cuda=False,
+                 is_user=False,
+                 model_file=None):
+        if not os.path.isfile(archive_file):
+            if not model_file:
+                raise Exception("No model for SC-LSTM is specified!")
+            archive_file = cached_path(model_file)
+        model_dir = os.path.dirname(os.path.abspath(__file__))
+        if not os.path.exists(os.path.join(model_dir, 'resource')):
+            archive = zipfile.ZipFile(archive_file, 'r')
+            archive.extractall(model_dir)
+
+        self.USE_CUDA = use_cuda
+        self.args, self.config = parse(is_user)
+        self.dataset = SimpleDatasetWoz(self.config)
+
+        # get model hyper-parameters
+        hidden_size = self.config.getint('MODEL', 'hidden_size')
+
+        # get feat size
+        d_size = self.dataset.do_size + self.dataset.da_size + self.dataset.sv_size  # len of 1-hot feat
+        vocab_size = len(self.dataset.word2index)
+
+        self.model = LMDeep('sclstm', vocab_size, vocab_size, hidden_size, d_size,
+                            n_layer=self.args['n_layer'], use_cuda=use_cuda)
+        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), self.args['model_path'])
+        assert os.path.isfile(model_path)
+        self.model.load_state_dict(torch.load(model_path))
+        self.model.eval()
+        if use_cuda:
+            self.model.cuda()
+    def generate_delex(self, meta):
+        """
+        Generate a delexicalized response for a dialog act dict, e.g.
+        meta = {"Attraction-Inform": [["Choice", "many"], ["Area", "centre of town"]],
+                "Attraction-Select": [["Type", "church"], ["Type", " swimming"], ["Type", " park"]]}
+        """
+        # add placeholder value
+        for k, v in meta.items():
+            domain, intent = k.split('-')
+            if intent == "Request":
+                for pair in v:
+                    if not isinstance(pair[1], str):
+                        pair[1] = str(pair[1])
+                    pair.insert(1, '?')
+            else:
+                counter = {}
+                for pair in v:
+                    if not isinstance(pair[1], str):
+                        pair[1] = str(pair[1])
+                    if pair[0] == 'Internet' or pair[0] == 'Parking':
+                        pair.insert(1, 'yes')
+                    elif pair[0] == 'none':
+                        pair.insert(1, 'none')
+                    else:
+                        if pair[0] in counter:
+                            counter[pair[0]] += 1
+                        else:
+                            counter[pair[0]] = 1
+                        pair.insert(1, str(counter[pair[0]]))
+
+        # remove invalid dialog acts
+        meta_ = deepcopy(meta)
+        for k, v in meta.items():
+            for triple in v:
+                voc = 'd-a-s-v:' + k + '-' + triple[0] + '-' + triple[1]
+                if voc not in self.dataset.cardinality:
+                    meta_[k].remove(triple)
+            if not meta_[k]:
+                del (meta_[k])
+        meta = meta_
+
+        # map the inputs to 1-hot condition features
+        do_idx, da_idx, sv_idx, featStr = self.dataset.getFeatIdx(meta)
+        do_cond = [1 if i in do_idx else 0 for i in range(self.dataset.do_size)]  # domain condition
+        da_cond = [1 if i in da_idx else 0 for i in range(self.dataset.da_size)]  # dial act condition
+        sv_cond = [1 if i in sv_idx else 0 for i in range(self.dataset.sv_size)]  # slot/value condition
+        feats = [do_cond + da_cond + sv_cond]
+
+        feats_var = torch.FloatTensor(feats)
+        if self.USE_CUDA:
+            feats_var = feats_var.cuda()
+
+        decoded_words = self.model.generate(self.dataset, feats_var, self.args['beam_size'])
+        delex = decoded_words[0]  # (beam_size)
+
+        return delex
+ +
+    def generate_slots(self, meta):
+        meta = deepcopy(meta)
+
+        delex = self.generate_delex(meta)
+        # get all informable or requestable slots
+        slots = []
+        for sen in delex:
+            slot = []
+            counter = {}
+            words = sen.split()
+            for word in words:
+                if word.startswith('slot-'):
+                    placeholder = word[5:]
+                    if placeholder not in counter:
+                        counter[placeholder] = 1
+                    else:
+                        counter[placeholder] += 1
+                    slot.append(placeholder + '-' + str(counter[placeholder]))
+            slots.append(slot)
+
+        return slots[0]
+ +
+    def generate(self, meta):
+        meta = deepcopy(meta)
+
+        delex = self.generate_delex(meta)
+
+        # replace the placeholders with entities
+        recover = []
+        for sen in delex:
+            counter = {}
+            words = sen.split()
+            for word in words:
+                if word.startswith('slot-'):
+                    flag = True
+                    _, domain, intent, slot_type = word.split('-')
+                    da = domain.capitalize() + '-' + intent.capitalize()
+                    if da in meta:
+                        key = da + '-' + slot_type.capitalize()
+                        for pair in meta[da]:
+                            if (pair[0].lower() == slot_type) and (
+                                    (key not in counter) or (counter[key] == int(pair[1]) - 1)):
+                                sen = sen.replace(word, pair[2], 1)
+                                counter[key] = int(pair[1])
+                                flag = False
+                                break
+                    if flag:
+                        sen = sen.replace(word, '', 1)
+            recover.append(sen)
+
+        return recover[0]
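A minimal usage sketch: the model URL matches the one used by evaluate.py earlier in this diff, and the dialog act is the one from the generate_delex docstring; the printed outputs are illustrative, not guaranteed:

nlg = SCLSTM(model_file="https://convlab.blob.core.windows.net/models/nlg-sclstm-multiwoz.zip")
meta = {"Attraction-Inform": [["Choice", "many"], ["Area", "centre of town"]]}
print(nlg.generate(meta))        # lexicalized surface form
print(nlg.generate_slots(meta))  # e.g. ['attraction-inform-choice-1', 'attraction-inform-area-1']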
+ +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlg/multiwoz/template_nlg.html b/docs/build/html/_modules/convlab/modules/nlg/multiwoz/template_nlg.html
new file mode 100644
index 0000000..3882ca1
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlg/multiwoz/template_nlg.html
@@ -0,0 +1,221 @@

Source code for convlab.modules.nlg.multiwoz.template_nlg

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from convlab.modules.nlg.nlg import NLG
+
+
+class TemplateNLG(NLG):
+    def __init__(self):
+        NLG.__init__(self)
+ +
+    def generate(self, dialog_act):
+        phrases = []
+        for da in dialog_act.keys():
+            domain, type = da.split('-')
+            if domain == 'general':
+                if type == 'hello':
+                    phrases.append('hello, i need help')
+                else:
+                    phrases.append('bye')
+            elif type == 'Request':
+                for slot, value in dialog_act[da]:
+                    phrases.append('what is the {}'.format(slot))
+            else:
+                for slot, value in dialog_act[da]:
+                    phrases.append('i want the {} to be {}'.format(slot, value))
+        sent = ', '.join(phrases)
+        return sent
+
+
+if __name__ == '__main__':
+    nlg = TemplateNLG()
+    user_acts = [{"Restaurant-Inform": [["Food", "japanese"], ["Time", "17:45"]]},
+                 {"Restaurant-Request": [["Price", "?"]]},
+                 {"general-bye": [["none", "none"]]}]
+    for ua in user_acts:
+        sent = nlg.generate(ua)
+        print(sent)
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlg/multiwoz/utils.html b/docs/build/html/_modules/convlab/modules/nlg/multiwoz/utils.html
new file mode 100644
index 0000000..8b9a941
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlg/multiwoz/utils.html
@@ -0,0 +1,206 @@

Source code for convlab.modules.nlg.multiwoz.utils

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+'''Numeric helpers (weight initialization, dict merging) for the MultiWOZ NLG modules.'''
+
+import math
+
+import numpy as np
+
+
+def initWeights(n, d):
+    """ Initialization strategy: uniform in [-r, r] with r = sqrt(6 / (n + d)) (Glorot-style). """
+    scale_factor = math.sqrt(float(6) / (n + d))
+    return (np.random.rand(n, d) * 2 - 1) * scale_factor
+ +
+def mergeDicts(d0, d1):
+    """ For all k in d1, do d0[k] += d1[k]; d's are dictionaries of key -> numpy array """
+    for k in d1:
+        if k in d0:
+            d0[k] += d1[k]
+        else:
+            d0[k] = d1[k]
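A quick sketch exercising both helpers (shapes chosen arbitrarily):

import numpy as np

W = initWeights(4, 8)  # 4x8 matrix, uniform in [-sqrt(6/12), +sqrt(6/12)]
d0 = {'W': np.zeros((4, 8))}
d1 = {'W': np.ones((4, 8)), 'b': np.zeros(8)}
mergeDicts(d0, d1)                 # d0['W'] += d1['W']; d0 gains key 'b'
print(d0['W'].mean(), sorted(d0))  # 1.0 ['W', 'b']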
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlg/nlg.html b/docs/build/html/_modules/convlab/modules/nlg/nlg.html
new file mode 100644
index 0000000..d9447fb
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlg/nlg.html
@@ -0,0 +1,207 @@

Source code for convlab.modules.nlg.nlg

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+"""
+
+
+class NLG:
+    """Base class for NLG model."""
+
+    def __init__(self):
+        """ Constructor for NLG class. """
+        pass
+    def generate(self, dialog_act):
+        """
+        Generate a natural language utterance conditioned on the dialog act produced by Agenda or Policy.
+        Args:
+            dialog_act (dict): The dialog act of the following system response. The dialog act can be
+                produced either by the user agenda or by the system policy module.
+        Returns:
+            response (str): The natural language utterance for the input dialog_act.
+        """
+        pass
+
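For orientation, a toy subclass sketch of this interface (the echo behavior is invented; the real implementations in this diff are MultiwozTemplateNLG, TemplateNLG and SCLSTM):

class EchoNLG(NLG):
    """Toy NLG: verbalizes each act-slot-value triple literally."""
    def generate(self, dialog_act):
        parts = ['{} {}={}'.format(da, s, v)
                 for da, svs in dialog_act.items() for s, v in svs]
        return '; '.join(parts)

print(EchoNLG().generate({"Restaurant-Inform": [["Food", "japanese"]]}))
# Restaurant-Inform Food=japanese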
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/error.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/error.html
new file mode 100644
index 0000000..b3ca44a
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/error.html
@@ -0,0 +1,221 @@

Source code for convlab.modules.nlu.multiwoz.error

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+"""
+
+
+class ErrorNLU:
+    """Base model for generating NLU errors."""
+
+    def __init__(self, act_type_rate=0.0, slot_rate=0.0):
+        """
+        Args:
+            act_type_rate (float): The error rate applied on dialog act types.
+            slot_rate (float): The error rate applied on slots.
+        """
+        self.set_error_rate(act_type_rate, slot_rate)
+    def set_error_rate(self, act_type_rate, slot_rate):
+        """
+        Set the error rate parameters for the error model.
+        Args:
+            act_type_rate (float): The error rate applied on dialog act types.
+            slot_rate (float): The error rate applied on slots.
+        """
+        self.act_type_rate = act_type_rate
+        self.slot_rate = slot_rate
+ +
+    def apply(self, dialog_act):
+        """
+        Apply the error model on a dialog act.
+        Args:
+            dialog_act (tuple): Dialog act.
+        Returns:
+            dialog_act (tuple): Dialog act with noise.
+        """
+        # TODO: not yet implemented
+        return
+
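`apply` is still a TODO; one plausible sketch of the intended behavior, dropping slots at slot_rate (the noising scheme here is an assumption, not ConvLab's implementation):

import random

def noisy_apply(error_nlu, dialog_act):
    # Sketch: keep each slot-value pair with probability 1 - slot_rate;
    # dialog_act = {act: [[slot, value], ...]}
    noisy = {}
    for act, svs in dialog_act.items():
        kept = [sv for sv in svs if random.random() >= error_nlu.slot_rate]
        if kept:
            noisy[act] = kept
    return noisy

err = ErrorNLU(act_type_rate=0.0, slot_rate=0.3)
print(noisy_apply(err, {"Restaurant-Inform": [["Food", "japanese"], ["Time", "17:45"]]}))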
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/evaluate.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/evaluate.html
new file mode 100644
index 0000000..1a17e8e
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/evaluate.html
@@ -0,0 +1,268 @@

Source code for convlab.modules.nlu.multiwoz.evaluate

+"""
+Evaluate NLU models on the MultiWOZ test dataset
+Metric: dataset level Precision/Recall/F1
+Usage: PYTHONPATH=../../../.. python evaluate.py [OneNetLU|MILU|SVMNLU]
+"""
+import json
+import random
+import sys
+import zipfile
+
+import numpy
+import torch
+
+from convlab.modules.nlu.multiwoz import MILU
+from convlab.modules.nlu.multiwoz import OneNetLU
+from convlab.modules.nlu.multiwoz import SVMNLU
+
+seed = 2019
+random.seed(seed)
+numpy.random.seed(seed)
+torch.manual_seed(seed)
+
+
+
+def da2triples(dialog_act):
+    triples = []
+    for intent, svs in dialog_act.items():
+        for slot, value in svs:
+            triples.append((intent, slot, value))
+    return triples
+
+
+if __name__ == '__main__':
+    if len(sys.argv) != 2:
+        print("usage:")
+        print("\t python evaluate.py model_name")
+        print("\t model_name=OneNetLU, MILU, or SVMNLU")
+        sys.exit()
+    model_name = sys.argv[1]
+    if model_name == 'OneNetLU':
+        model = OneNetLU(model_file="https://convlab.blob.core.windows.net/models/onenet.tar.gz")
+    elif model_name == 'MILU':
+        model = MILU(model_file="https://convlab.blob.core.windows.net/models/milu.tar.gz")
+    elif model_name == 'SVMNLU':
+        model = SVMNLU(model_file="https://convlab.blob.core.windows.net/models/svm_multiwoz.zip")
+    else:
+        raise Exception("Available model: OneNetLU, MILU, SVMNLU")
+
+    archive = zipfile.ZipFile('../../../../data/multiwoz/test.json.zip', 'r')
+    test_data = json.load(archive.open('test.json'))
+    TP, FP, FN = 0, 0, 0
+    sen_num = 0
+    sess_num = 0
+    for no, session in test_data.items():
+        sen_num += len(session['log'])
+        sess_num += 1
+        if sess_num % 10 == 0:
+            # report running metrics every 10 sessions
+            print('Session [%d|%d]' % (sess_num, len(test_data)))
+            precision = 1.0 * TP / (TP + FP)
+            recall = 1.0 * TP / (TP + FN)
+            F1 = 2.0 * precision * recall / (precision + recall)
+            print('Model {} on {} session {} sentences:'.format(model_name, sess_num, sen_num))
+            print('\t Precision: %.2f' % (100 * precision))
+            print('\t Recall: %.2f' % (100 * recall))
+            print('\t F1: %.2f' % (100 * F1))
+        for i, turn in enumerate(session['log']):
+            labels = da2triples(turn['dialog_act'])
+            predicts = da2triples(model.parse(turn['text']))
+            for triple in predicts:
+                if triple in labels:
+                    TP += 1
+                else:
+                    FP += 1
+            for triple in labels:
+                if triple not in predicts:
+                    FN += 1
+    print(TP, FP, FN)
+    precision = 1.0 * TP / (TP + FP)
+    recall = 1.0 * TP / (TP + FN)
+    F1 = 2.0 * precision * recall / (precision + recall)
+    print('Model {} on {} session {} sentences:'.format(model_name, len(test_data), sen_num))
+    print('\t Precision: %.2f' % (100 * precision))
+    print('\t Recall: %.2f' % (100 * recall))
+    print('\t F1: %.2f' % (100 * F1))
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/dai_f1_measure.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/dai_f1_measure.html
new file mode 100644
index 0000000..0feba2c
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/dai_f1_measure.html
@@ -0,0 +1,269 @@

Source code for convlab.modules.nlu.multiwoz.milu.dai_f1_measure

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from typing import Dict, List, Any
+
+from allennlp.training.metrics.metric import Metric
+
+
+
+class DialogActItemF1Measure(Metric):
+    """
+    Computes precision, recall and F1 over dialog act items (act type + slot-value entries).
+    """
+    def __init__(self) -> None:
+        # These will hold the item counts.
+        self._true_positives = 0
+        self._false_positives = 0
+        self._false_negatives = 0
+
+    def __call__(self,
+                 predictions: List[Dict[str, Any]],
+                 gold_labels: List[Dict[str, Any]]):
+        """
+        Parameters
+        ----------
+        predictions : ``List[Dict[str, Any]]``, required.
+            Predicted dialog acts, one dict per instance, mapping act type to its slot-value items.
+        gold_labels : ``List[Dict[str, Any]]``, required.
+            Gold dialog acts in the same format and of the same length as ``predictions``.
+        """
+        for prediction, gold_label in zip(predictions, gold_labels):
+            for dat in prediction:
+                for sv in prediction[dat]:
+                    if dat not in gold_label or sv not in gold_label[dat]:
+                        self._false_positives += 1
+                    else:
+                        self._true_positives += 1
+            for dat in gold_label:
+                for sv in gold_label[dat]:
+                    if dat not in prediction or sv not in prediction[dat]:
+                        self._false_negatives += 1
+    def get_metric(self, reset: bool = False):
+        """
+        Returns
+        -------
+        A Dict with the following metrics, computed jointly over all dialog act items:
+        precision : float
+        recall : float
+        f1-measure : float
+        """
+        precision, recall, f1_measure = self._compute_metrics(self._true_positives,
+                                                              self._false_positives,
+                                                              self._false_negatives)
+        metrics = {}
+        metrics["precision"] = precision
+        metrics["recall"] = recall
+        metrics["f1-measure"] = f1_measure
+        if reset:
+            self.reset()
+        return metrics
+
+    @staticmethod
+    def _compute_metrics(true_positives: int, false_positives: int, false_negatives: int):
+        precision = float(true_positives) / float(true_positives + false_positives + 1e-13)
+        recall = float(true_positives) / float(true_positives + false_negatives + 1e-13)
+        f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13))
+        return precision, recall, f1_measure
+
+    def reset(self):
+        self._true_positives = 0
+        self._false_positives = 0
+        self._false_negatives = 0
+
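A worked example of the counting above (dialog acts invented):

metric = DialogActItemF1Measure()
pred = [{"Restaurant-Inform": [["Food", "japanese"], ["Area", "west"]]}]
gold = [{"Restaurant-Inform": [["Food", "japanese"], ["Time", "17:45"]]}]
metric(pred, gold)          # TP=1 (Food), FP=1 (Area), FN=1 (Time)
print(metric.get_metric())  # precision = recall = f1-measure = 0.5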
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/dataset_reader.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/dataset_reader.html
new file mode 100644
index 0000000..64ea194
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/dataset_reader.html
@@ -0,0 +1,334 @@

Source code for convlab.modules.nlu.multiwoz.milu.dataset_reader

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import json
+import logging
+import os
+import random
+import zipfile
+from typing import Dict, List, Any
+
+from allennlp.data.dataset_readers.dataset_reader import DatasetReader
+from allennlp.data.fields import TextField, SequenceLabelField, MultiLabelField, MetadataField, Field
+from allennlp.data.instance import Instance
+from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
+from allennlp.data.tokenizers import Token
+from overrides import overrides
+
+from convlab.lib.file_util import cached_path
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+
+@DatasetReader.register("milu")
+class MILUDatasetReader(DatasetReader):
+    """
+    Reads MultiWOZ-style dialog annotations from a json file (optionally inside a zip archive)
+    and converts them into ``Instances`` with tokens, BIO span tags and intent labels suitable
+    for joint NLU training.
+
+    Parameters
+    ----------
+    context_size : ``int``, optional (default=``0``)
+        Number of previous turns to include as context.
+    agent : ``str``, optional (default=``None``)
+        If ``"user"`` or ``"system"``, only that side's turns are read.
+    random_context_size : ``bool``, optional (default=``True``)
+        If true, a context size in [0, context_size] is sampled per turn.
+    token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
+        We use this to define the input representation for the text. See :class:`TokenIndexer`.
+    """
+    def __init__(self,
+                 context_size: int = 0,
+                 agent: str = None,
+                 random_context_size: bool = True,
+                 token_delimiter: str = None,
+                 token_indexers: Dict[str, TokenIndexer] = None,
+                 lazy: bool = False) -> None:
+        super().__init__(lazy)
+        self._context_size = context_size
+        self._agent = agent
+        self._random_context_size = random_context_size
+        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
+        self._token_delimiter = token_delimiter
+
+    @overrides
+    def _read(self, file_path):
+        # if `file_path` is a URL, redirect to the cache
+        file_path = cached_path(file_path)
+
+        if file_path.endswith("zip"):
+            archive = zipfile.ZipFile(file_path, "r")
+            data_file = archive.open(os.path.basename(file_path)[:-4])
+        else:
+            data_file = open(file_path, "r")
+
+        logger.info("Reading instances from lines in file at: %s", file_path)
+
+        dialogs = json.load(data_file)
+
+        for dial_name in dialogs:
+            dialog = dialogs[dial_name]["log"]
+            context_tokens_list = []
+            for i, turn in enumerate(dialog):
+                tokens = turn["text"].split()
+
+                dialog_act = {}
+                for dacts in turn["span_info"]:
+                    if dacts[0] not in dialog_act:
+                        dialog_act[dacts[0]] = []
+                    dialog_act[dacts[0]].append([dacts[1], " ".join(tokens[dacts[3]: dacts[4] + 1])])
+
+                spans = turn["span_info"]
+                tags = []
+                for j in range(len(tokens)):  # token index; renamed from `i` to avoid shadowing the turn index
+                    for span in spans:
+                        if j == span[3]:
+                            tags.append("B-" + span[0] + "+" + span[1])
+                            break
+                        if j > span[3] and j <= span[4]:
+                            tags.append("I-" + span[0] + "+" + span[1])
+                            break
+                    else:
+                        tags.append("O")
+
+                intents = []
+                for dacts in turn["dialog_act"]:
+                    for dact in turn["dialog_act"][dacts]:
+                        if dacts not in dialog_act or dact[0] not in [sv[0] for sv in dialog_act[dacts]]:
+                            if dact[1] in ["none", "?", "yes", "no", "do nt care", "do n't care"]:
+                                intents.append(dacts + "+" + dact[0] + "*" + dact[1])
+
+                for dacts in turn["dialog_act"]:
+                    for dact in turn["dialog_act"][dacts]:
+                        if dacts not in dialog_act:
+                            dialog_act[dacts] = turn["dialog_act"][dacts]
+                            break
+                        elif dact[0] not in [sv[0] for sv in dialog_act[dacts]]:
+                            dialog_act[dacts].append(dact)
+
+                num_context = random.randint(0, self._context_size) if self._random_context_size \
+                    else self._context_size
+                if len(context_tokens_list) > 0 and num_context > 0:
+                    wrapped_context_tokens = [Token(token) for context_tokens in context_tokens_list[-num_context:]
+                                              for token in context_tokens]
+                else:
+                    wrapped_context_tokens = [Token("SENT_END")]
+                wrapped_tokens = [Token(token) for token in tokens]
+                context_tokens_list.append(tokens + ["SENT_END"])
+
+                if self._agent and self._agent == "user" and i % 2 != 1:
+                    continue
+                if self._agent and self._agent == "system" and i % 2 != 0:
+                    continue
+                yield self.text_to_instance(wrapped_context_tokens, wrapped_tokens, tags, intents, dialog_act)
+    def text_to_instance(self, context_tokens: List[Token], tokens: List[Token], tags: List[str] = None,
+                         intents: List[str] = None, dialog_act: Dict[str, Any] = None) -> Instance:  # type: ignore
+        """
+        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.
+        """
+        # pylint: disable=arguments-differ
+        fields: Dict[str, Field] = {}
+        fields["context_tokens"] = TextField(context_tokens, self._token_indexers)
+        fields["tokens"] = TextField(tokens, self._token_indexers)
+        if tags is not None:
+            fields["tags"] = SequenceLabelField(tags, fields["tokens"])
+        if intents is not None:
+            fields["intents"] = MultiLabelField(intents, label_namespace="intent_labels")
+        fields["metadata"] = MetadataField({"words": [x.text for x in tokens],
+                                            "dialog_act": dialog_act if dialog_act is not None else {}})
+        return Instance(fields)
+
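To make the tag and intent encoding concrete, a sketch of what one turn yields under the logic in _read (annotation values invented, following the MultiWOZ json layout):

turn = {
    "text": "i want japanese food",
    "span_info": [["Restaurant-Inform", "Food", "japanese", 2, 2]],
    "dialog_act": {"Restaurant-Inform": [["Food", "japanese"]]},
}
# tokens  -> ['i', 'want', 'japanese', 'food']
# tags    -> ['O', 'O', 'B-Restaurant-Inform+Food', 'O']
# intents -> []  (the slot is grounded by a span, so no intent label is added)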
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/evaluate.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/evaluate.html
new file mode 100644
index 0000000..83ba047
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/evaluate.html
@@ -0,0 +1,309 @@

Source code for convlab.modules.nlu.multiwoz.milu.evaluate

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+The ``evaluate`` subcommand can be used to
+evaluate a trained model against a dataset
+and report any metrics calculated by the model.
+"""
+import argparse
+import json
+import logging
+from typing import Dict, Any
+
+from allennlp.common import Params
+from allennlp.common.util import prepare_environment
+from allennlp.data.dataset_readers.dataset_reader import DatasetReader
+from allennlp.data.iterators import DataIterator
+from allennlp.models.archival import load_archive
+from allennlp.training.util import evaluate
+
+from convlab.modules.nlu.multiwoz.milu import dataset_reader, model 
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+argparser = argparse.ArgumentParser(description="Evaluate the specified model + dataset.")
+argparser.add_argument('archive_file', type=str, help='path to an archived trained model')
+
+argparser.add_argument('input_file', type=str, help='path to the file containing the evaluation data')
+
+argparser.add_argument('--output-file', type=str, help='path to output file')
+
+argparser.add_argument('--weights-file',
+                        type=str,
+                        help='a path that overrides which weights file to use')
+
+cuda_device = argparser.add_mutually_exclusive_group(required=False)
+cuda_device.add_argument('--cuda-device',
+                            type=int,
+                            default=-1,
+                            help='id of GPU to use (if any)')
+
+argparser.add_argument('-o', '--overrides',
+                        type=str,
+                        default="",
+                        help='a JSON structure used to override the experiment configuration')
+
+argparser.add_argument('--batch-weight-key',
+                        type=str,
+                        default="",
+                        help='If non-empty, name of metric used to weight the loss on a per-batch basis.')
+
+argparser.add_argument('--extend-vocab',
+                        action='store_true',
+                        default=False,
+                        help='if specified, we will use the instances in your new dataset to '
+                            'extend your vocabulary. If pretrained-file was used to initialize '
+                            'embedding layers, you may also need to pass --embedding-sources-mapping.')
+
+argparser.add_argument('--embedding-sources-mapping',
+                        type=str,
+                        default="",
+                        help='a JSON dict defining mapping from embedding module path to embedding '
+                        'pretrained-file used during training. If not passed, and embedding needs to be '
+                        'extended, we will try to use the original file paths used during training. If '
+                        'they are not available we will use random vectors for embedding extension.')
+
+
+
[docs]def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]: + # Disable some of the more verbose logging statements + logging.getLogger('allennlp.common.params').disabled = True + logging.getLogger('allennlp.nn.initializers').disabled = True + logging.getLogger('allennlp.modules.token_embedders.embedding').setLevel(logging.INFO) + + # Load from archive + archive = load_archive(args.archive_file, args.cuda_device, args.overrides, args.weights_file) + config = archive.config + prepare_environment(config) + model = archive.model + model.eval() + + # Load the evaluation data + + # Try to use the validation dataset reader if there is one - otherwise fall back + # to the default dataset_reader used for both training and validation. + validation_dataset_reader_params = config.pop('validation_dataset_reader', None) + if validation_dataset_reader_params is not None: + dataset_reader = DatasetReader.from_params(validation_dataset_reader_params) + else: + dataset_reader = DatasetReader.from_params(config.pop('dataset_reader')) + evaluation_data_path = args.input_file + logger.info("Reading evaluation data from %s", evaluation_data_path) + instances = dataset_reader.read(evaluation_data_path) + + embedding_sources: Dict[str, str] = (json.loads(args.embedding_sources_mapping) + if args.embedding_sources_mapping else {}) + if args.extend_vocab: + logger.info("Vocabulary is being extended with test instances.") + model.vocab.extend_from_instances(Params({}), instances=instances) + model.extend_embedder_vocab(embedding_sources) + + iterator_params = config.pop("validation_iterator", None) + if iterator_params is None: + iterator_params = config.pop("iterator") + iterator = DataIterator.from_params(iterator_params) + iterator.index_with(model.vocab) + + metrics = evaluate(model, instances, iterator, args.cuda_device, args.batch_weight_key) + + logger.info("Finished evaluating.") + logger.info("Metrics:") + for key, metric in metrics.items(): + logger.info("%s: %s", key, metric) + + output_file = args.output_file + if output_file: + with open(output_file, "w") as file: + json.dump(metrics, file, indent=4) + return metrics
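A hedged usage sketch: the evaluator can also be driven programmatically through the same ``argparser``; both paths below are placeholders:

    args = argparser.parse_args([
        "models/milu.tar.gz",             # placeholder: trained model archive
        "data/multiwoz/test.json.zip",    # placeholder: evaluation data
        "--output-file", "metrics.json",
        "--cuda-device", "-1",
    ])
    metrics = evaluate_from_args(args)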
+ + +if __name__ == "__main__": + args = argparser.parse_args() + evaluate_from_args(args) +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/model.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/model.html
new file mode 100644
index 0000000..466560b
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/model.html
@@ -0,0 +1,579 @@
+convlab.modules.nlu.multiwoz.milu.model — ConvLab 0.1 documentation
Source code for convlab.modules.nlu.multiwoz.milu.model

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from math import log10
+from typing import Dict, Optional, List, Any
+
+import allennlp.nn.util as util
+import numpy as np
+import torch
+from allennlp.common.checks import check_dimensions_match, ConfigurationError
+from allennlp.data import Vocabulary
+from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans
+from allennlp.models.model import Model
+from allennlp.modules import Attention, ConditionalRandomField, FeedForward
+from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder
+from allennlp.modules.attention import LegacyAttention
+from allennlp.modules.conditional_random_field import allowed_transitions
+from allennlp.modules.similarity_functions import SimilarityFunction
+from allennlp.nn import InitializerApplicator, RegularizerApplicator
+from allennlp.nn.util import sequence_cross_entropy_with_logits
+from allennlp.training.metrics import SpanBasedF1Measure
+from overrides import overrides
+from torch.nn.modules.linear import Linear
+
+from convlab.modules.nlu.multiwoz.milu.dai_f1_measure import DialogActItemF1Measure
+from convlab.modules.nlu.multiwoz.milu.multilabel_f1_measure import MultiLabelF1Measure
+
+
+
[docs]@Model.register("milu")
+class MILU(Model):
+    """
+    The ``MILU`` encodes a sequence of text with a ``Seq2SeqEncoder``,
+    then performs multi-label classification for closed-class dialog act items and
+    sequence labeling to predict a tag for each token in the sequence.
+
+    Parameters
+    ----------
+    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
+        Used to initialize the model parameters.
+    regularizer : ``RegularizerApplicator``, optional (default=``None``)
+        If provided, will be used to calculate the regularization penalty during training.
+    """
+
+    def __init__(self, vocab: Vocabulary,
+                 text_field_embedder: TextFieldEmbedder,
+                 encoder: Seq2SeqEncoder,
+                 intent_encoder: Seq2SeqEncoder = None,
+                 tag_encoder: Seq2SeqEncoder = None,
+                 attention: Attention = None,
+                 attention_function: SimilarityFunction = None,
+                 context_for_intent: bool = True,
+                 context_for_tag: bool = True,
+                 attention_for_intent: bool = True,
+                 attention_for_tag: bool = True,
+                 sequence_label_namespace: str = "labels",
+                 intent_label_namespace: str = "intent_labels",
+                 feedforward: Optional[FeedForward] = None,
+                 label_encoding: Optional[str] = None,
+                 include_start_end_transitions: bool = True,
+                 crf_decoding: bool = False,
+                 constrain_crf_decoding: bool = None,
+                 focal_loss_gamma: float = None,
+                 nongeneral_intent_weight: float = 5.,
+                 num_train_examples: float = None,
+                 calculate_span_f1: bool = None,
+                 dropout: Optional[float] = None,
+                 verbose_metrics: bool = False,
+                 initializer: InitializerApplicator = InitializerApplicator(),
+                 regularizer: Optional[RegularizerApplicator] = None) -> None:
+        super().__init__(vocab, regularizer)
+
+        self.context_for_intent = context_for_intent
+        self.context_for_tag = context_for_tag
+        self.attention_for_intent = attention_for_intent
+        self.attention_for_tag = attention_for_tag
+        self.sequence_label_namespace = sequence_label_namespace
+        self.intent_label_namespace = intent_label_namespace
+        self.text_field_embedder = text_field_embedder
+        self.num_tags = self.vocab.get_vocab_size(sequence_label_namespace)
+        self.num_intents = self.vocab.get_vocab_size(intent_label_namespace)
+        self.encoder = encoder
+        self.intent_encoder = intent_encoder
+        self.tag_encoder = tag_encoder  # was ``intent_encoder`` in the source, which left this argument unused
+        self._feedforward = feedforward
+        self._verbose_metrics = verbose_metrics
+        self.rl = False
+
+        if attention:
+            if attention_function:
+                raise ConfigurationError("You can only specify an attention module or an "
+                                         "attention function, but not both.")
+            self.attention = attention
+        elif attention_function:
+            self.attention = LegacyAttention(attention_function)
+
+        if dropout:
+            self.dropout = torch.nn.Dropout(dropout)
+        else:
+            self.dropout = None
+
+        projection_input_dim = feedforward.get_output_dim() if self._feedforward else self.encoder.get_output_dim()
+        if self.context_for_intent:
+            projection_input_dim += self.encoder.get_output_dim()
+        if self.attention_for_intent:
+            projection_input_dim += self.encoder.get_output_dim()
+        self.intent_projection_layer = Linear(projection_input_dim, self.num_intents)
+
+        if num_train_examples:
+            try:
+                pos_weight = torch.tensor([log10((num_train_examples - self.vocab._retained_counter[intent_label_namespace][t]) /
+                                                 self.vocab._retained_counter[intent_label_namespace][t]) for i, t in
+                                           self.vocab.get_index_to_token_vocabulary(intent_label_namespace).items()])
+            except Exception:  # was a bare ``except:``; the vocab may not retain label counts
+                pos_weight = torch.tensor([1. for i, t in
+                                           self.vocab.get_index_to_token_vocabulary(intent_label_namespace).items()])
+        else:
+            pos_weight = torch.tensor([(lambda t: nongeneral_intent_weight if "Request" in t else 1.)(t) for i, t in
+                                       self.vocab.get_index_to_token_vocabulary(intent_label_namespace).items()])
+        self.intent_loss = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="none")
+
+        tag_projection_input_dim = feedforward.get_output_dim() if self._feedforward else self.encoder.get_output_dim()
+        if self.context_for_tag:
+            tag_projection_input_dim += self.encoder.get_output_dim()
+        if self.attention_for_tag:
+            tag_projection_input_dim += self.encoder.get_output_dim()
+        self.tag_projection_layer = TimeDistributed(Linear(tag_projection_input_dim,
+                                                           self.num_tags))
+
+        # If constrain_crf_decoding and calculate_span_f1 are not provided (i.e., they're
+        # None), set them to True if label_encoding is provided and False if it isn't.
+        if constrain_crf_decoding is None:
+            constrain_crf_decoding = label_encoding is not None
+        if calculate_span_f1 is None:
+            calculate_span_f1 = label_encoding is not None
+
+        self.label_encoding = label_encoding
+        if constrain_crf_decoding:
+            if not label_encoding:
+                raise ConfigurationError("constrain_crf_decoding is True, but "
+                                         "no label_encoding was specified.")
+            labels = self.vocab.get_index_to_token_vocabulary(sequence_label_namespace)
+            constraints = allowed_transitions(label_encoding, labels)
+        else:
+            constraints = None
+
+        self.include_start_end_transitions = include_start_end_transitions
+        if crf_decoding:
+            self.crf = ConditionalRandomField(
+                    self.num_tags, constraints,
+                    include_start_end_transitions=include_start_end_transitions)
+        else:
+            self.crf = None
+
+        self._intent_f1_metric = MultiLabelF1Measure(vocab,
+                                                     namespace=intent_label_namespace)
+        self.calculate_span_f1 = calculate_span_f1
+        if calculate_span_f1:
+            if not label_encoding:
+                raise ConfigurationError("calculate_span_f1 is True, but "
+                                         "no label_encoding was specified.")
+            self._f1_metric = SpanBasedF1Measure(vocab,
+                                                 tag_namespace=sequence_label_namespace,
+                                                 label_encoding=label_encoding)
+        self._dai_f1_metric = DialogActItemF1Measure()
+
+        check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
+                               "text field embedding dim", "encoder input dim")
+        if feedforward is not None:
+            check_dimensions_match(encoder.get_output_dim(), feedforward.get_input_dim(),
+                                   "encoder output dim", "feedforward input dim")
+        initializer(self)
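A toy illustration of the intent re-weighting above, assuming per-label counts are available: each label's positive weight is log10(negatives / positives), so rarer intents receive larger weights (the labels and counts are invented):

    from math import log10
    import torch

    num_train_examples = 10000
    intent_counts = {"Hotel-Inform+Parking*yes": 4000,   # common: weight ~0.18
                     "Hotel-Request+Internet*?": 250}    # rare:   weight ~1.59
    pos_weight = torch.tensor(
        [log10((num_train_examples - c) / c) for c in intent_counts.values()])
    intent_loss = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="none")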
[docs] @overrides + def forward(self, # type: ignore + context_tokens: Dict[str, torch.LongTensor], + tokens: Dict[str, torch.LongTensor], + tags: torch.LongTensor = None, + intents: torch.LongTensor = None, + metadata: List[Dict[str, Any]] = None, + # pylint: disable=unused-argument + **kwargs) -> Dict[str, torch.Tensor]: + # pylint: disable=arguments-differ + """ + Parameters + ---------- + + Returns + ------- + """ + if self.context_for_intent or self.context_for_tag or \ + self.attention_for_intent or self.attention_for_tag: + embedded_context_input = self.text_field_embedder(context_tokens) + + if self.dropout: + embedded_context_input = self.dropout(embedded_context_input) + + context_mask = util.get_text_field_mask(context_tokens) + encoded_context = self.encoder(embedded_context_input, context_mask) + + if self.dropout: + encoded_context = self.dropout(encoded_context) + + encoded_context_summary = util.get_final_encoder_states( + encoded_context, + context_mask, + self.encoder.is_bidirectional()) + + embedded_text_input = self.text_field_embedder(tokens) + mask = util.get_text_field_mask(tokens) + + if self.dropout: + embedded_text_input = self.dropout(embedded_text_input) + + encoded_text = self.encoder(embedded_text_input, mask) + + if self.dropout: + encoded_text = self.dropout(encoded_text) + + intent_encoded_text = self.intent_encoder(encoded_text, mask) if self.intent_encoder else encoded_text + + if self.dropout and self.intent_encoder: + intent_encoded_text = self.dropout(intent_encoded_text) + + is_bidirectional = self.intent_encoder.is_bidirectional() if self.intent_encoder else self.encoder.is_bidirectional() + if self._feedforward is not None: + encoded_summary = self._feedforward(util.get_final_encoder_states( + intent_encoded_text, + mask, + is_bidirectional)) + else: + encoded_summary = util.get_final_encoder_states( + intent_encoded_text, + mask, + is_bidirectional) + + tag_encoded_text = self.tag_encoder(encoded_text, mask) if self.tag_encoder else encoded_text + + if self.dropout and self.tag_encoder: + tag_encoded_text = self.dropout(tag_encoded_text) + + if self.attention_for_intent or self.attention_for_tag: + attention_weights = self.attention(encoded_summary, encoded_context, context_mask.float()) + attended_context = util.weighted_sum(encoded_context, attention_weights) + + if self.context_for_intent: + encoded_summary = torch.cat([encoded_summary, encoded_context_summary], dim=-1) + + if self.attention_for_intent: + encoded_summary = torch.cat([encoded_summary, attended_context], dim=-1) + + if self.context_for_tag: + tag_encoded_text = torch.cat([tag_encoded_text, + encoded_context_summary.unsqueeze(dim=1).expand( + encoded_context_summary.size(0), + tag_encoded_text.size(1), + encoded_context_summary.size(1))], dim=-1) + + if self.attention_for_tag: + tag_encoded_text = torch.cat([tag_encoded_text, + attended_context.unsqueeze(dim=1).expand( + attended_context.size(0), + tag_encoded_text.size(1), + attended_context.size(1))], dim=-1) + + intent_logits = self.intent_projection_layer(encoded_summary) + intent_probs = torch.sigmoid(intent_logits) + predicted_intents = (intent_probs > 0.5).long() + + sequence_logits = self.tag_projection_layer(tag_encoded_text) + if self.crf is not None: + best_paths = self.crf.viterbi_tags(sequence_logits, mask) + # Just get the tags and ignore the score. 
+ predicted_tags = [x for x, y in best_paths] + else: + predicted_tags = self.get_predicted_tags(sequence_logits) + + output = {"sequence_logits": sequence_logits, "mask": mask, "tags": predicted_tags, + "intent_logits": intent_logits, "intent_probs": intent_probs, "intents": predicted_intents} + + if tags is not None: + if self.crf is not None: + # Add negative log-likelihood as loss + log_likelihood = self.crf(sequence_logits, tags, mask) + output["loss"] = -log_likelihood + + # Represent viterbi tags as "class probabilities" that we can + # feed into the metrics + class_probabilities = sequence_logits * 0. + for i, instance_tags in enumerate(predicted_tags): + for j, tag_id in enumerate(instance_tags): + class_probabilities[i, j, tag_id] = 1 + else: + loss = sequence_cross_entropy_with_logits(sequence_logits, tags, mask) + class_probabilities = sequence_logits + output["loss"] = loss + + if self.calculate_span_f1: + self._f1_metric(class_probabilities, tags, mask.float()) + + if metadata is not None: + output["words"] = [x["words"] for x in metadata] + + if tags is not None and metadata: + self.decode(output) + self._dai_f1_metric(output["dialog_act"], [x["dialog_act"] for x in metadata]) + rewards = self.get_rewards(output["dialog_act"], [x["dialog_act"] for x in metadata]) if self.rl else None + + if intents is not None: + output["loss"] += torch.mean(self.intent_loss(intent_logits, intents.float())) + self._intent_f1_metric(predicted_intents, intents) + + return output
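A hedged sketch of running the model on one instance via the AllenNLP 0.x API, assuming ``model`` is a trained ``MILU`` and ``instance`` comes from the dataset reader's ``text_to_instance``:

    outputs = model.forward_on_instance(instance)
    # forward_on_instance also applies ``decode`` below, so alongside "tags",
    # "intents" and the logits the dict should carry "dialog_act".
    print(outputs["dialog_act"])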
+ + +
[docs] def get_predicted_tags(self, sequence_logits: torch.Tensor) -> torch.Tensor: + """ + Does a simple position-wise argmax over each token, converts indices to string labels, and + adds a ``"tags"`` key to the dictionary with the result. + """ + all_predictions = sequence_logits + all_predictions = all_predictions.detach().cpu().numpy() + if all_predictions.ndim == 3: + predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])] + else: + predictions_list = [all_predictions] + all_tags = [] + for predictions in predictions_list: + tags = np.argmax(predictions, axis=-1) + all_tags.append(tags) + return all_tags
+ + +
[docs] @overrides + def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Converts the tag ids to the actual tags. + ``output_dict["tags"]`` is a list of lists of tag_ids, + so we use an ugly nested list comprehension. + """ + output_dict["tags"] = [ + [self.vocab.get_token_from_index(tag, namespace=self.sequence_label_namespace) + for tag in instance_tags] + for instance_tags in output_dict["tags"] + ] + output_dict["intents"] = [ + [self.vocab.get_token_from_index(intent[0], namespace=self.intent_label_namespace) + for intent in instance_intents.nonzero().tolist()] + for instance_intents in output_dict["intents"] + ] + + output_dict["dialog_act"] = [] + for i, tags in enumerate(output_dict["tags"]): + seq_len = len(output_dict["words"][i]) + spans = bio_tags_to_spans(tags[:seq_len]) + dialog_act = {} + for span in spans: + domain_act = span[0].split("+")[0] + slot = span[0].split("+")[1] + value = " ".join(output_dict["words"][i][span[1][0]:span[1][1]+1]) + if domain_act not in dialog_act: + dialog_act[domain_act] = [[slot, value]] + else: + dialog_act[domain_act].append([slot, value]) + for intent in output_dict["intents"][i]: + if "+" in intent: + if "*" in intent: + intent, value = intent.split("*", 1) + else: + value = "?" + domain_act = intent.split("+")[0] + if domain_act not in dialog_act: + dialog_act[domain_act] = [[intent.split("+")[1], value]] + else: + dialog_act[domain_act].append([intent.split("+")[1], value]) + else: + dialog_act[intent] = [["none", "none"]] + output_dict["dialog_act"].append(dialog_act) + + return output_dict
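The span-to-act conversion in ``decode`` can be checked in isolation: tag labels have the shape "<domain-intent>+<slot>", and each BIO span yields one [slot, value] pair (the labels below are made up):

    from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans

    words = ["i", "want", "a", "cheap", "hotel"]
    tags = ["O", "O", "O", "B-Hotel-Inform+Price", "O"]
    dialog_act = {}
    for label, (start, end) in bio_tags_to_spans(tags):
        domain_act, slot = label.split("+")
        value = " ".join(words[start:end + 1])       # spans are inclusive
        dialog_act.setdefault(domain_act, []).append([slot, value])
    print(dialog_act)  # {'Hotel-Inform': [['Price', 'cheap']]}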
+ + +
[docs] @overrides + def get_metrics(self, reset: bool = False) -> Dict[str, float]: + metrics_to_return = {} + intent_f1_dict = self._intent_f1_metric.get_metric(reset=reset) + metrics_to_return.update({"int_"+x[:1]: y for x, y in intent_f1_dict.items() if "overall" in x}) + if self.calculate_span_f1: + f1_dict = self._f1_metric.get_metric(reset=reset) + metrics_to_return.update({"tag_"+x[:1]: y for x, y in f1_dict.items() if "overall" in x}) + metrics_to_return.update(self._dai_f1_metric.get_metric(reset=reset)) + return metrics_to_return
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/multilabel_f1_measure.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/multilabel_f1_measure.html
new file mode 100644
index 0000000..311785f
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/multilabel_f1_measure.html
@@ -0,0 +1,318 @@
+convlab.modules.nlu.multiwoz.milu.multilabel_f1_measure — ConvLab 0.1 documentation
Source code for convlab.modules.nlu.multiwoz.milu.multilabel_f1_measure

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from collections import defaultdict
+from typing import Dict, List, Optional, Set
+
+import torch
+from allennlp.data.vocabulary import Vocabulary
+from allennlp.training.metrics.metric import Metric
+
+
+
[docs]@Metric.register("multilabel_f1")
+class MultiLabelF1Measure(Metric):
+    """
+    Computes micro-averaged precision, recall and F1 for multi-label
+    classification, either jointly over all labels ("coarse") or per label.
+    """
+    def __init__(self,
+                 vocabulary: Vocabulary,
+                 namespace: str = "intent_labels",
+                 ignore_classes: List[str] = None,
+                 coarse: bool = True) -> None:
+        """
+        Parameters
+        ----------
+        vocabulary : ``Vocabulary``, required.
+            A vocabulary containing the label namespace.
+        namespace : str, required.
+            The vocabulary namespace for labels.
+        ignore_classes : List[str], optional.
+            Labels which will be ignored when computing metrics.
+        coarse : bool, optional (default = True).
+            If ``True``, only the joint ("overall") counts are accumulated.
+        """
+        self._label_vocabulary = vocabulary.get_index_to_token_vocabulary(namespace)
+        self._ignore_classes: List[str] = ignore_classes or []
+        self._coarse = coarse
+
+        # These will hold per label counts.
+        self._true_positives: Dict[str, int] = defaultdict(int)
+        self._false_positives: Dict[str, int] = defaultdict(int)
+        self._false_negatives: Dict[str, int] = defaultdict(int)
+
+    def __call__(self,
+                 predictions: torch.Tensor,
+                 gold_labels: torch.Tensor,
+                 mask: Optional[torch.Tensor] = None):
+        """
+        Parameters
+        ----------
+        predictions : ``torch.Tensor``, required.
+            A multi-hot tensor of predictions of shape (batch_size, num_classes).
+        gold_labels : ``torch.Tensor``, required.
+            A multi-hot tensor of gold labels with the same shape as ``predictions``.
+        mask: ``torch.Tensor``, optional (default = None).
+            A masking tensor the same size as ``gold_labels``.
+        """
+        if mask is None:
+            mask = torch.ones_like(gold_labels)
+
+        predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask)
+
+        if self._coarse:
+            num_positives = predictions.sum()
+            num_false_positives = ((predictions - gold_labels) > 0).long().sum()
+            self._false_positives["coarse_overall"] += num_false_positives
+            num_true_positives = num_positives - num_false_positives
+            self._true_positives["coarse_overall"] += num_true_positives
+            num_false_negatives = ((gold_labels - predictions) > 0).long().sum()
+            self._false_negatives["coarse_overall"] += num_false_negatives
+        else:
+            # Iterate over instances in the batch.
+            batch_size = gold_labels.size(0)
+            for i in range(batch_size):
+                prediction = predictions[i, :]
+                gold_label = gold_labels[i, :]
+                for label_id in range(gold_label.size(-1)):
+                    label = self._label_vocabulary[label_id]
+                    if prediction[label_id] == 1 and gold_label[label_id] == 1:
+                        self._true_positives[label] += 1
+                    elif prediction[label_id] == 1 and gold_label[label_id] == 0:
+                        self._false_positives[label] += 1
+                    elif prediction[label_id] == 0 and gold_label[label_id] == 1:
+                        self._false_negatives[label] += 1
[docs] def get_metric(self, reset: bool = False): + """ + Returns + ------- + A Dict per label containing following the span based metrics: + precision : float + recall : float + f1-measure : float + + Additionally, an ``overall`` key is included, which provides the precision, + recall and f1-measure for all spans. + """ + all_labels: Set[str] = set() + all_labels.update(self._true_positives.keys()) + all_labels.update(self._false_positives.keys()) + all_labels.update(self._false_negatives.keys()) + all_metrics = {} + for label in all_labels: + precision, recall, f1_measure = self._compute_metrics(self._true_positives[label], + self._false_positives[label], + self._false_negatives[label]) + precision_key = "precision" + "-" + label + recall_key = "recall" + "-" + label + f1_key = "f1-measure" + "-" + label + all_metrics[precision_key] = precision + all_metrics[recall_key] = recall + all_metrics[f1_key] = f1_measure + + # Compute the precision, recall and f1 for all spans jointly. + precision, recall, f1_measure = self._compute_metrics(sum(self._true_positives.values()), + sum(self._false_positives.values()), + sum(self._false_negatives.values())) + all_metrics["precision-overall"] = precision + all_metrics["recall-overall"] = recall + all_metrics["f1-measure-overall"] = f1_measure + if reset: + self.reset() + return all_metrics
+    @staticmethod
+    def _compute_metrics(true_positives: int, false_positives: int, false_negatives: int):
+        precision = float(true_positives) / float(true_positives + false_positives) if true_positives + false_positives > 0 else 0
+        recall = float(true_positives) / float(true_positives + false_negatives) if true_positives + false_negatives > 0 else 0
+        f1_measure = 2. * ((precision * recall) / (precision + recall)) if precision + recall > 0 else 0
+        return precision, recall, f1_measure
[docs] def reset(self): + self._true_positives = defaultdict(int) + self._false_positives = defaultdict(int) + self._false_negatives = defaultdict(int)
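A toy check of the "coarse" counting in ``__call__`` on multi-hot tensors, mirroring the arithmetic above:

    import torch

    predictions = torch.tensor([[1, 0, 1], [0, 1, 0]])
    gold_labels = torch.tensor([[1, 0, 0], [0, 1, 1]])
    num_positives = predictions.sum()                                  # 3
    false_positives = ((predictions - gold_labels) > 0).long().sum()   # 1
    true_positives = num_positives - false_positives                   # 2
    false_negatives = ((gold_labels - predictions) > 0).long().sum()   # 1
    # micro precision = 2/3, recall = 2/3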
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/nlu.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/nlu.html
new file mode 100644
index 0000000..8a05559
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/nlu.html
@@ -0,0 +1,313 @@
+convlab.modules.nlu.multiwoz.milu.nlu — ConvLab 0.1 documentation
Source code for convlab.modules.nlu.multiwoz.milu.nlu

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+"""
+
+import os
+from pprint import pprint
+
+from allennlp.common.checks import check_for_gpu
+from allennlp.data import DatasetReader
+from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
+from allennlp.models.archival import load_archive
+
+from convlab.lib.file_util import cached_path
+from convlab.modules.nlu.nlu import NLU
+from convlab.modules.nlu.multiwoz.milu import dataset_reader, model  # imported for side effects: registers the "milu" reader and model
+
+DEFAULT_CUDA_DEVICE = -1
+DEFAULT_DIRECTORY = "models"
+DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "milu.tar.gz")
+
+
[docs]class MILU(NLU): + """Multi-intent language understanding model.""" + + def __init__(self, + archive_file=DEFAULT_ARCHIVE_FILE, + cuda_device=DEFAULT_CUDA_DEVICE, + model_file=None, + context_size=3): + """ Constructor for NLU class. """ + + self.context_size = context_size + + check_for_gpu(cuda_device) + + if not os.path.isfile(archive_file): + if not model_file: + raise Exception("No model for MILU is specified!") + + archive_file = cached_path(model_file) + + archive = load_archive(archive_file, + cuda_device=cuda_device) + self.tokenizer = SpacyWordSplitter(language="en_core_web_sm") + dataset_reader_params = archive.config["dataset_reader"] + self.dataset_reader = DatasetReader.from_params(dataset_reader_params) + self.model = archive.model + self.model.eval() + + +
[docs] def parse(self, utterance, context=None):
+        """
+        Predict the dialog act of a natural language utterance and apply error model.
+        Args:
+            utterance (str): A natural language utterance.
+            context (list of str, optional): Previous utterances, most recent last.
+        Returns:
+            output (dict): The dialog act of utterance.
+        """
+        if len(utterance) == 0:
+            return {}
+
+        context = context or []  # avoid the mutable default argument in the original
+        if self.context_size > 0 and len(context) > 0:
+            # Flatten the last ``context_size`` turns, appending a sentence-end marker to each.
+            context_tokens = sum([self.tokenizer.split_words(turn + " SENT_END")
+                                  for turn in context[-self.context_size:]], [])
+        else:
+            context_tokens = self.tokenizer.split_words("SENT_END")
+        tokens = self.tokenizer.split_words(utterance)
+        instance = self.dataset_reader.text_to_instance(context_tokens, tokens)
+        outputs = self.model.forward_on_instance(instance)
+
+        return outputs["dialog_act"]
+ + +if __name__ == "__main__": + nlu = MILU() + test_contexts = [ + "SENT_END", + "SENT_END", + "SENT_END", + "SENT_END", + "SENT_END", + "SENT_END", + "SENT_END", + "SENT_END", + "SENT_END", + "SENT_END", + "SENT_END", + "SENT_END", + ] + test_utterances = [ + "What type of accommodations are they. No , i just need their address . Can you tell me if the hotel has internet available ?", + "What type of accommodations are they.", + "No , i just need their address .", + "Can you tell me if the hotel has internet available ?", + "you're welcome! enjoy your visit! goodbye.", + "yes. it should be moderately priced.", + "i want to book a table for 6 at 18:45 on thursday", + "i will be departing out of stevenage.", + "What is the Name of attraction ?", + "Can I get the name of restaurant?", + "Can I get the address and phone number of the restaurant?", + "do you have a specific area you want to stay in?" + ] + for ctxt, utt in zip(test_contexts, test_utterances): + print(ctxt) + print(utt) + pprint(nlu.parse(utt)) + # pprint(nlu.parse(utt.lower())) + + test_contexts = [ + "The phone number of the hotel is 12345678", + "I have many that meet your requests", + "The phone number of the hotel is 12345678", + "I found one hotel room", + "thank you", + "Is it moderately priced?", + "Can I help you with booking?", + "Where are you departing from?", + "I found an attraction", + "I found a restaurant", + "I found a restaurant", + "I'm looking for a place to stay.", + ] + for ctxt, utt in zip(test_contexts, test_utterances): + print(ctxt) + print(utt) + pprint(nlu.parse(utt, [ctxt])) + # pprint(nlu.parse(utt.lower(), ctxt.lower())) +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/train.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/train.html
new file mode 100644
index 0000000..9a542d4
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/milu/train.html
@@ -0,0 +1,385 @@
+convlab.modules.nlu.multiwoz.milu.train — ConvLab 0.1 documentation
Source code for convlab.modules.nlu.multiwoz.milu.train

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+The ``train`` subcommand can be used to train a model.
+It requires a configuration file and a directory in
+which to write the results.
+"""
+
+import argparse
+import logging
+import os
+
+from allennlp.common import Params
+from allennlp.common.checks import check_for_gpu
+from allennlp.common.util import prepare_environment, prepare_global_logging, cleanup_global_logging, dump_metrics
+from allennlp.models.archival import archive_model, CONFIG_NAME
+from allennlp.models.model import Model, _DEFAULT_WEIGHTS
+from allennlp.training.trainer import Trainer, TrainerPieces
+from allennlp.training.trainer_base import TrainerBase
+from allennlp.training.util import create_serialization_dir, evaluate
+
+from convlab.modules.nlu.multiwoz.milu import dataset_reader, model  # imported for side effects: registers the "milu" reader and model
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+argparser = argparse.ArgumentParser(description="Train a model.")
+argparser.add_argument('param_path',
+                        type=str,
+                        help='path to parameter file describing the model to be trained')
+argparser.add_argument('-s', '--serialization-dir',
+                        required=True,
+                        type=str,
+                        help='directory in which to save the model and its logs')
+argparser.add_argument('-r', '--recover',
+                        action='store_true',
+                        default=False,
+                        help='recover training from the state in serialization_dir')
+argparser.add_argument('-f', '--force',
+                        action='store_true',
+                        required=False,
+                        help='overwrite the output directory if it exists')
+argparser.add_argument('-o', '--overrides',
+                        type=str,
+                        default="",
+                        help='a JSON structure used to override the experiment configuration')
+argparser.add_argument('--file-friendly-logging',
+                        action='store_true',
+                        default=False,
+                        help='outputs tqdm status on separate lines and slows tqdm refresh rate')
+
+
+
+
[docs]def train_model_from_args(args: argparse.Namespace): + """ + Just converts from an ``argparse.Namespace`` object to string paths. + """ + train_model_from_file(args.param_path, + args.serialization_dir, + args.overrides, + args.file_friendly_logging, + args.recover, + args.force)
+ + +
[docs]def train_model_from_file(parameter_filename: str,
+                          serialization_dir: str,
+                          overrides: str = "",
+                          file_friendly_logging: bool = False,
+                          recover: bool = False,
+                          force: bool = False) -> Model:
+    """
+    A wrapper around :func:`train_model` which loads the params from a file.
+
+    Parameters
+    ----------
+    parameter_filename : ``str``
+        A json parameter file specifying an AllenNLP experiment.
+    serialization_dir : ``str``
+        The directory in which to save results and logs. We just pass this along to
+        :func:`train_model`.
+    overrides : ``str``
+        A JSON string that we will use to override values in the input parameter file.
+    file_friendly_logging : ``bool``, optional (default=False)
+        If ``True``, we make our output more friendly to saved model files. We just pass this
+        along to :func:`train_model`.
+    recover : ``bool``, optional (default=False)
+        If ``True``, we will try to recover a training run from an existing serialization
+        directory. This is only intended for use when something actually crashed during the middle
+        of a run. For continuing training a model on new data, see the ``fine-tune`` command.
+    force : ``bool``, optional (default=False)
+        If ``True``, we will overwrite the serialization directory if it already exists.
+    """
+    # Load the experiment config from a file and pass it to ``train_model``.
+    params = Params.from_file(parameter_filename, overrides)
+    return train_model(params, serialization_dir, file_friendly_logging, recover, force)
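A hedged usage sketch; the configuration file and output directory are placeholders:

    model = train_model_from_file("configs/milu.jsonnet",   # placeholder config
                                  "output/milu",            # placeholder directory
                                  force=True)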
+ + +
[docs]def train_model(params: Params, + serialization_dir: str, + file_friendly_logging: bool = False, + recover: bool = False, + force: bool = False) -> Model: + """ + Trains the model specified in the given :class:`Params` object, using the data and training + parameters also specified in that object, and saves the results in ``serialization_dir``. + + Parameters + ---------- + params : ``Params`` + A parameter object specifying an AllenNLP Experiment. + serialization_dir : ``str`` + The directory in which to save results and logs. + file_friendly_logging : ``bool``, optional (default=False) + If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow + down tqdm's output to only once every 10 seconds. + recover : ``bool``, optional (default=False) + If ``True``, we will try to recover a training run from an existing serialization + directory. This is only intended for use when something actually crashed during the middle + of a run. For continuing training a model on new data, see the ``fine-tune`` command. + force : ``bool``, optional (default=False) + If ``True``, we will overwrite the serialization directory if it already exists. + + Returns + ------- + best_model: ``Model`` + The model with the best epoch weights. + """ + prepare_environment(params) + create_serialization_dir(params, serialization_dir, recover, force) + stdout_handler = prepare_global_logging(serialization_dir, file_friendly_logging) + + cuda_device = params.params.get('trainer').get('cuda_device', -1) + check_for_gpu(cuda_device) + + params.to_file(os.path.join(serialization_dir, CONFIG_NAME)) + + evaluate_on_test = params.pop_bool("evaluate_on_test", False) + + trainer_type = params.get("trainer", {}).get("type", "default") + + if trainer_type == "default": + # Special logic to instantiate backward-compatible trainer. + pieces = TrainerPieces.from_params(params, serialization_dir, recover) # pylint: disable=no-member + trainer = Trainer.from_params( + model=pieces.model, + serialization_dir=serialization_dir, + iterator=pieces.iterator, + train_data=pieces.train_dataset, + validation_data=pieces.validation_dataset, + params=pieces.params, + validation_iterator=pieces.validation_iterator) + evaluation_iterator = pieces.validation_iterator or pieces.iterator + evaluation_dataset = pieces.test_dataset + + else: + trainer = TrainerBase.from_params(params, serialization_dir, recover) + # TODO(joelgrus): handle evaluation in the general case + evaluation_iterator = evaluation_dataset = None + + params.assert_empty('base train command') + + try: + metrics = trainer.train() + except KeyboardInterrupt: + # if we have completed an epoch, try to create a model archive. + if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)): + logging.info("Training interrupted by the user. Attempting to create " + "a model archive using the current best epoch weights.") + archive_model(serialization_dir, files_to_archive=params.files_to_archive) + raise + + # Evaluate + if evaluation_dataset and evaluate_on_test: + logger.info("The model will be evaluated using the best epoch weights.") + test_metrics = evaluate(trainer.model, evaluation_dataset, evaluation_iterator, + cuda_device=trainer._cuda_devices[0], # pylint: disable=protected-access, + # TODO(brendanr): Pass in an arg following Joel's trainer refactor. 
+ batch_weight_key="") + + for key, value in test_metrics.items(): + metrics["test_" + key] = value + + elif evaluation_dataset: + logger.info("To evaluate on the test set after training, pass the " + "'evaluate_on_test' flag, or use the 'allennlp evaluate' command.") + + cleanup_global_logging(stdout_handler) + + # Now tar up results + archive_model(serialization_dir, files_to_archive=params.files_to_archive) + dump_metrics(os.path.join(serialization_dir, "metrics.json"), metrics, log=True) + + # We count on the trainer to have the model with best weights + return trainer.model
+ +if __name__ == "__main__": + args = argparser.parse_args() + train_model_from_args(args) +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/dai_f1_measure.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/dai_f1_measure.html
new file mode 100644
index 0000000..16a05c5
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/dai_f1_measure.html
@@ -0,0 +1,269 @@
+convlab.modules.nlu.multiwoz.onenet.dai_f1_measure — ConvLab 0.1 documentation
Source code for convlab.modules.nlu.multiwoz.onenet.dai_f1_measure

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from typing import Dict, List, Any
+
+from allennlp.training.metrics.metric import Metric
+
+
+
[docs]class DialogActItemF1Measure(Metric):
+    """
+    Computes precision, recall and F1 over dialog act items, where an item is one
+    (dialog act, slot, value) triple inside a dialog act dict.
+    """
+    def __init__(self) -> None:
+        # These hold the item counts accumulated across calls.
+        self._true_positives = 0
+        self._false_positives = 0
+        self._false_negatives = 0
+
+    def __call__(self,
+                 predictions: List[Dict[str, Any]],
+                 gold_labels: List[Dict[str, Any]]):
+        """
+        Parameters
+        ----------
+        predictions : ``List[Dict[str, Any]]``, required.
+            One predicted dialog act dict per instance, mapping "<domain>-<intent>"
+            keys to lists of [slot, value] pairs.
+        gold_labels : ``List[Dict[str, Any]]``, required.
+            The corresponding gold dialog act dicts in the same format.
+        """
+        for prediction, gold_label in zip(predictions, gold_labels):
+            for dat in prediction:
+                for sv in prediction[dat]:
+                    if dat not in gold_label or sv not in gold_label[dat]:
+                        self._false_positives += 1
+                    else:
+                        self._true_positives += 1
+            for dat in gold_label:
+                for sv in gold_label[dat]:
+                    if dat not in prediction or sv not in prediction[dat]:
+                        self._false_negatives += 1
[docs] def get_metric(self, reset: bool = False): + """ + Returns + ------- + A Dict per label containing following the span based metrics: + precision : float + recall : float + f1-measure : float + + Additionally, an ``overall`` key is included, which provides the precision, + recall and f1-measure for all spans. + """ + # Compute the precision, recall and f1 for all spans jointly. + precision, recall, f1_measure = self._compute_metrics(self._true_positives, + self._false_positives, + self._false_negatives) + metrics = {} + metrics["precision"] = precision + metrics["recall"] = recall + metrics["f1-measure"] = f1_measure + if reset: + self.reset() + return metrics
+ + + @staticmethod + def _compute_metrics(true_positives: int, false_positives: int, false_negatives: int): + precision = float(true_positives) / float(true_positives + false_positives + 1e-13) + recall = float(true_positives) / float(true_positives + false_negatives + 1e-13) + f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13)) + return precision, recall, f1_measure + + +
[docs] def reset(self): + self._true_positives = 0 + self._false_positives = 0 + self._false_negatives = 0
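A toy run of this metric: one matching item, one spurious item and one missed item give precision = recall = 0.5 (up to the 1e-13 smoothing):

    metric = DialogActItemF1Measure()
    metric([{"Hotel-Inform": [["Price", "cheap"], ["Area", "west"]]}],
           [{"Hotel-Inform": [["Price", "cheap"], ["Stars", "4"]]}])
    print(metric.get_metric())
    # {'precision': ~0.5, 'recall': ~0.5, 'f1-measure': ~0.5}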
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/dataset_reader.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/dataset_reader.html
new file mode 100644
index 0000000..d887d5c
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/dataset_reader.html
@@ -0,0 +1,320 @@
+convlab.modules.nlu.multiwoz.onenet.dataset_reader — ConvLab 0.1 documentation
Source code for convlab.modules.nlu.multiwoz.onenet.dataset_reader

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import json
+import logging
+import os
+import zipfile
+from typing import Dict, List, Any
+
+from allennlp.data.dataset_readers.dataset_reader import DatasetReader
+from allennlp.data.fields import TextField, SequenceLabelField, LabelField, MetadataField, Field
+from allennlp.data.instance import Instance
+from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
+from allennlp.data.tokenizers import Token
+from overrides import overrides
+
+from convlab.lib.file_util import cached_path
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+
[docs]@DatasetReader.register("onenet")
+class OneNetDatasetReader(DatasetReader):
+    """
+    Reads a MultiWOZ-style JSON file and converts each dialog turn into an
+    instance suitable for joint domain/intent classification and sequence tagging.
+
+    Parameters
+    ----------
+    token_delimiter : ``str``, optional
+        Stored for interface compatibility; tokens are split on whitespace.
+    token_indexers : ``Dict[str, TokenIndexer]``, optional
+        Indexers for the token field (default: a single-id indexer).
+    lazy : ``bool``, optional
+        Passed through to ``DatasetReader``.
+    """
+    def __init__(self,
+                 token_delimiter: str = None,
+                 token_indexers: Dict[str, TokenIndexer] = None,
+                 lazy: bool = False) -> None:
+        super().__init__(lazy)
+        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
+        self._token_delimiter = token_delimiter
+
+    @overrides
+    def _read(self, file_path):
+        # if `file_path` is a URL, redirect to the cache
+        file_path = cached_path(file_path)
+
+        if file_path.endswith("zip"):
+            archive = zipfile.ZipFile(file_path, "r")
+            data_file = archive.open(os.path.basename(file_path)[:-4])
+        else:
+            data_file = open(file_path, "r")
+
+        logger.info("Reading instances from lines in file at: %s", file_path)
+
+        dialogs = json.load(data_file)
+
+        for dial_name in dialogs:
+            dialog = dialogs[dial_name]["log"]
+            for turn in dialog:
+                tokens = turn["text"].split()
+                spans = turn["span_info"]
+                tags = []
+                domain = "None"
+                intent = "None"
+                for i in range(len(tokens)):
+                    for span in spans:
+                        if i == span[3]:
+                            new_domain, new_intent = span[0].split("-", 1)
+                            if domain == "None":
+                                domain = new_domain
+                            elif domain != new_domain:
+                                continue
+                            if intent == "None":
+                                intent = new_intent
+                            elif intent != new_intent:
+                                continue
+                            tags.append("B-"+span[1])
+                            break
+                        if i > span[3] and i <= span[4]:
+                            new_domain, new_intent = span[0].split("-", 1)
+                            if domain != new_domain:
+                                continue
+                            if intent != new_intent:
+                                continue
+                            tags.append("I-"+span[1])
+                            break
+                    else:
+                        tags.append("O")
+
+                if domain != "None":
+                    assert intent != "None", "intent must not be None when domain is not None"
+                elif turn["dialog_act"] != {}:
+                    assert intent == "None", "intent must be None when domain is None"
+                    di = list(turn["dialog_act"].keys())[0]
+                    dai = turn["dialog_act"][di][0]
+                    domain = di.split("-")[0]
+                    intent = di.split("-", 1)[-1] + "+" + dai[0] + "*" + dai[1]
+
+                dialog_act = {}
+                for dacts in turn["span_info"]:
+                    if dacts[0] not in dialog_act:
+                        dialog_act[dacts[0]] = []
+                    dialog_act[dacts[0]].append([dacts[1], " ".join(tokens[dacts[3]: dacts[4]+1])])
+
+                for dacts in turn["dialog_act"]:
+                    for dact in turn["dialog_act"][dacts]:
+                        if dacts not in dialog_act:
+                            dialog_act[dacts] = turn["dialog_act"][dacts]
+                            break
+                        elif dact[0] not in [sv[0] for sv in dialog_act[dacts]]:
+                            dialog_act[dacts].append(dact)
+
+                tokens = [Token(token) for token in tokens]
+
+                yield self.text_to_instance(tokens, tags, domain, intent, dialog_act)
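A simplified walk-through of the tagging loop above, with the domain/intent bookkeeping omitted: MultiWOZ ``span_info`` entries have the form [dialog_act, slot, value, start, end] with inclusive token indices:

    tokens = "i want a cheap hotel".split()
    spans = [["Hotel-Inform", "Price", "cheap", 3, 3]]
    tags = []
    for i in range(len(tokens)):
        for span in spans:
            if i == span[3]:
                tags.append("B-" + span[1])
                break
            if span[3] < i <= span[4]:
                tags.append("I-" + span[1])
                break
        else:
            tags.append("O")
    print(tags)  # ['O', 'O', 'O', 'B-Price', 'O']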
[docs] def text_to_instance(self, tokens: List[Token], tags: List[str] = None, domain: str = None, + intent: str = None, dialog_act: Dict[str, Any] = None) -> Instance: # type: ignore + """ + We take `pre-tokenized` input here, because we don't have a tokenizer in this class. + """ + # pylint: disable=arguments-differ + fields: Dict[str, Field] = {} + sequence = TextField(tokens, self._token_indexers) + fields["tokens"] = sequence + if tags: + fields["tags"] = SequenceLabelField(tags, sequence) + if domain: + fields["domain"] = LabelField(domain, label_namespace="domain_labels") + if intent: + fields["intent"] = LabelField(intent, label_namespace="intent_labels") + if dialog_act is not None: + fields["metadata"] = MetadataField({"words": [x.text for x in tokens], + 'dialog_act': dialog_act}) + else: + fields["metadata"] = MetadataField({"words": [x.text for x in tokens], 'dialog_act': {}}) + return Instance(fields)
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/evaluate.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/evaluate.html
new file mode 100644
index 0000000..23837f0
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/evaluate.html
@@ -0,0 +1,309 @@
+convlab.modules.nlu.multiwoz.onenet.evaluate — ConvLab 0.1 documentation
Source code for convlab.modules.nlu.multiwoz.onenet.evaluate

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+The ``evaluate`` subcommand can be used to
+evaluate a trained model against a dataset
+and report any metrics calculated by the model.
+"""
+import argparse
+import json
+import logging
+from typing import Dict, Any
+
+from allennlp.common import Params
+from allennlp.common.util import prepare_environment
+from allennlp.data.dataset_readers.dataset_reader import DatasetReader
+from allennlp.data.iterators import DataIterator
+from allennlp.models.archival import load_archive
+from allennlp.training.util import evaluate
+
+from convlab.modules.nlu.multiwoz.onenet import dataset_reader, model  # imported for side effects: registers the "onenet" reader and model
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+argparser = argparse.ArgumentParser(description="Evaluate the specified model + dataset.")
+argparser.add_argument('archive_file', type=str, help='path to an archived trained model')
+
+argparser.add_argument('input_file', type=str, help='path to the file containing the evaluation data')
+
+argparser.add_argument('--output-file', type=str, help='path to output file')
+
+argparser.add_argument('--weights-file',
+                        type=str,
+                        help='a path that overrides which weights file to use')
+
+cuda_device = argparser.add_mutually_exclusive_group(required=False)
+cuda_device.add_argument('--cuda-device',
+                            type=int,
+                            default=-1,
+                            help='id of GPU to use (if any)')
+
+argparser.add_argument('-o', '--overrides',
+                        type=str,
+                        default="",
+                        help='a JSON structure used to override the experiment configuration')
+
+argparser.add_argument('--batch-weight-key',
+                        type=str,
+                        default="",
+                        help='If non-empty, name of metric used to weight the loss on a per-batch basis.')
+
+argparser.add_argument('--extend-vocab',
+                        action='store_true',
+                        default=False,
+                        help='if specified, we will use the instances in your new dataset to '
+                            'extend your vocabulary. If pretrained-file was used to initialize '
+                            'embedding layers, you may also need to pass --embedding-sources-mapping.')
+
+argparser.add_argument('--embedding-sources-mapping',
+                        type=str,
+                        default="",
+                        help='a JSON dict defining mapping from embedding module path to embedding '
+                        'pretrained-file used during training. If not passed, and embedding needs to be '
+                        'extended, we will try to use the original file paths used during training. If '
+                        'they are not available we will use random vectors for embedding extension.')
+
+
+
[docs]def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]: + # Disable some of the more verbose logging statements + logging.getLogger('allennlp.common.params').disabled = True + logging.getLogger('allennlp.nn.initializers').disabled = True + logging.getLogger('allennlp.modules.token_embedders.embedding').setLevel(logging.INFO) + + # Load from archive + archive = load_archive(args.archive_file, args.cuda_device, args.overrides, args.weights_file) + config = archive.config + prepare_environment(config) + model = archive.model + model.eval() + + # Load the evaluation data + + # Try to use the validation dataset reader if there is one - otherwise fall back + # to the default dataset_reader used for both training and validation. + validation_dataset_reader_params = config.pop('validation_dataset_reader', None) + if validation_dataset_reader_params is not None: + dataset_reader = DatasetReader.from_params(validation_dataset_reader_params) + else: + dataset_reader = DatasetReader.from_params(config.pop('dataset_reader')) + evaluation_data_path = args.input_file + logger.info("Reading evaluation data from %s", evaluation_data_path) + instances = dataset_reader.read(evaluation_data_path) + + embedding_sources: Dict[str, str] = (json.loads(args.embedding_sources_mapping) + if args.embedding_sources_mapping else {}) + if args.extend_vocab: + logger.info("Vocabulary is being extended with test instances.") + model.vocab.extend_from_instances(Params({}), instances=instances) + model.extend_embedder_vocab(embedding_sources) + + iterator_params = config.pop("validation_iterator", None) + if iterator_params is None: + iterator_params = config.pop("iterator") + iterator = DataIterator.from_params(iterator_params) + iterator.index_with(model.vocab) + + metrics = evaluate(model, instances, iterator, args.cuda_device, args.batch_weight_key) + + logger.info("Finished evaluating.") + logger.info("Metrics:") + for key, metric in metrics.items(): + logger.info("%s: %s", key, metric) + + output_file = args.output_file + if output_file: + with open(output_file, "w") as file: + json.dump(metrics, file, indent=4) + return metrics
+ + +if __name__ == "__main__": + args = argparser.parse_args() + evaluate_from_args(args) +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/model.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/model.html
new file mode 100644
index 0000000..378dcf9
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/model.html
@@ -0,0 +1,446 @@
+convlab.modules.nlu.multiwoz.onenet.model — ConvLab 0.1 documentation
Source code for convlab.modules.nlu.multiwoz.onenet.model

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from typing import Dict, Optional, List, Any
+
+import allennlp.nn.util as util
+import numpy as np
+import torch
+import torch.nn.functional as F
+from allennlp.common.checks import check_dimensions_match, ConfigurationError
+from allennlp.data import Vocabulary
+from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans
+from allennlp.models.model import Model
+from allennlp.modules import ConditionalRandomField, FeedForward
+from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder
+from allennlp.modules.conditional_random_field import allowed_transitions
+from allennlp.nn import InitializerApplicator, RegularizerApplicator
+from allennlp.nn.util import sequence_cross_entropy_with_logits
+from overrides import overrides
+from torch.nn.modules.linear import Linear
+
+from convlab.modules.nlu.multiwoz.onenet.dai_f1_measure import DialogActItemF1Measure
+
+
+
[docs]@Model.register("onenet") +class OneNet(Model): + """ + Parameters + ---------- + initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) + Used to initialize the model parameters. + regularizer : ``RegularizerApplicator``, optional (default=``None``) + If provided, will be used to calculate the regularization penalty during training. + """ + + def __init__(self, vocab: Vocabulary, + text_field_embedder: TextFieldEmbedder, + encoder: Seq2SeqEncoder, + tag_label_namespace: str = "labels", + domain_label_namespace: str = "domain_labels", + intent_label_namespace: str = "intent_labels", + feedforward: Optional[FeedForward] = None, + label_encoding: Optional[str] = None, + include_start_end_transitions: bool = True, + crf_decoding: bool = False, + constrain_crf_decoding: bool = None, + focal_loss_gamma: float = None, + calculate_span_f1: bool = None, + dropout: Optional[float] = None, + verbose_metrics: bool = False, + initializer: InitializerApplicator = InitializerApplicator(), + regularizer: Optional[RegularizerApplicator] = None) -> None: + super().__init__(vocab, regularizer) + + self.tag_label_namespace = tag_label_namespace + self.intent_label_namespace = intent_label_namespace + self.text_field_embedder = text_field_embedder + self.num_tags = self.vocab.get_vocab_size(tag_label_namespace) + self.num_domains = self.vocab.get_vocab_size(domain_label_namespace) + self.num_intents = self.vocab.get_vocab_size(intent_label_namespace) + self.encoder = encoder + self._verbose_metrics = verbose_metrics + if dropout: + self.dropout = torch.nn.Dropout(dropout) + else: + self.dropout = None + self._feedforward = feedforward + + self.tag_projection_layer = TimeDistributed(Linear(self.encoder.get_output_dim(), + self.num_tags)) + + if self._feedforward is not None: + self.domain_projection_layer = Linear(feedforward.get_output_dim(), self.num_domains) + self.intent_projection_layer = Linear(feedforward.get_output_dim(), self.num_intents) + else: + self.domain_projection_layer = Linear(self.encoder.get_output_dim(), self.num_domains) + self.intent_projection_layer = Linear(self.encoder.get_output_dim(), self.num_intents) + + self.ce_loss = torch.nn.CrossEntropyLoss() + + if constrain_crf_decoding is None: + constrain_crf_decoding = label_encoding is not None + if calculate_span_f1 is None: + calculate_span_f1 = label_encoding is not None + + self.label_encoding = label_encoding + if constrain_crf_decoding: + if not label_encoding: + raise ConfigurationError("constrain_crf_decoding is True, but " + "no label_encoding was specified.") + labels = self.vocab.get_index_to_token_vocabulary(tag_label_namespace) + constraints = allowed_transitions(label_encoding, labels) + else: + constraints = None + + self.include_start_end_transitions = include_start_end_transitions + if crf_decoding: + self.crf = ConditionalRandomField( + self.num_tags, constraints, + include_start_end_transitions=include_start_end_transitions + ) + else: + self.crf = None + + self._dai_f1_metric = DialogActItemF1Measure() + + check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(), + "text field embedding dim", "encoder input dim") + if feedforward is not None: + check_dimensions_match(encoder.get_output_dim(), feedforward.get_input_dim(), + "encoder output dim", "feedforward input dim") + initializer(self) + +
+    @overrides
+    def forward(self,  # type: ignore
+                tokens: Dict[str, torch.LongTensor],
+                tags: torch.LongTensor = None,
+                domain: torch.LongTensor = None,
+                intent: torch.LongTensor = None,
+                metadata: List[Dict[str, Any]] = None,
+                # pylint: disable=unused-argument
+                **kwargs) -> Dict[str, torch.Tensor]:
+        # pylint: disable=arguments-differ
+        """
+        Parameters
+        ----------
+        tokens : ``Dict[str, torch.LongTensor]``
+            The indexed input utterance, as produced by the text field's token indexers.
+        tags : ``torch.LongTensor``, optional
+            Gold slot tags, used to compute the tagging loss.
+        domain : ``torch.LongTensor``, optional
+            Gold domain label, used to compute the domain classification loss.
+        intent : ``torch.LongTensor``, optional
+            Gold intent label, used to compute the intent classification loss.
+        metadata : ``List[Dict[str, Any]]``, optional
+            Per-instance metadata; must contain the tokenized ``"words"``.
+
+        Returns
+        -------
+        An output dictionary with tag logits, predicted tags, domain and intent
+        probabilities, and, when labels are given, the combined ``loss``.
+        """
+        embedded_text_input = self.text_field_embedder(tokens)
+        mask = util.get_text_field_mask(tokens)
+
+        if self.dropout:
+            embedded_text_input = self.dropout(embedded_text_input)
+
+        encoded_text = self.encoder(embedded_text_input, mask)
+
+        if self.dropout:
+            encoded_text = self.dropout(encoded_text)
+
+        if self._feedforward is not None:
+            encoded_summary = self._feedforward(util.get_final_encoder_states(
+                encoded_text,
+                mask,
+                self.encoder.is_bidirectional()))
+        else:
+            encoded_summary = util.get_final_encoder_states(
+                encoded_text,
+                mask,
+                self.encoder.is_bidirectional())
+
+        tag_logits = self.tag_projection_layer(encoded_text)
+        if self.crf:
+            best_paths = self.crf.viterbi_tags(tag_logits, mask)
+            # Just get the tags and ignore the score.
+            predicted_tags = [x for x, y in best_paths]
+        else:
+            predicted_tags = self.get_predicted_tags(tag_logits)
+
+        domain_logits = self.domain_projection_layer(encoded_summary)
+        domain_probs = F.softmax(domain_logits, dim=-1)
+
+        intent_logits = self.intent_projection_layer(encoded_summary)
+        intent_probs = F.softmax(intent_logits, dim=-1)
+
+        output = {"tag_logits": tag_logits, "mask": mask, "tags": predicted_tags,
+                  "domain_probs": domain_probs, "intent_probs": intent_probs}
+
+        if tags is not None:
+            if self.crf:
+                # Add negative log-likelihood as loss
+                log_likelihood = self.crf(tag_logits, tags, mask)
+                output["loss"] = -log_likelihood
+
+                # Represent viterbi tags as "class probabilities" that we can
+                # feed into the metrics
+                class_probabilities = tag_logits * 0.
+                for i, instance_tags in enumerate(predicted_tags):
+                    for j, tag_id in enumerate(instance_tags):
+                        class_probabilities[i, j, tag_id] = 1
+            else:
+                loss = sequence_cross_entropy_with_logits(tag_logits, tags, mask)
+                class_probabilities = tag_logits
+                output["loss"] = loss
+
+            if domain is not None:
+                output["loss"] += self.ce_loss(domain_logits, domain)
+            if intent is not None:
+                output["loss"] += self.ce_loss(intent_logits, intent)
+
+        if metadata:
+            output["words"] = [x["words"] for x in metadata]
+
+        if tags is not None and metadata:
+            self.decode(output)
+            self._dai_f1_metric(output["dialog_act"], [x["dialog_act"] for x in metadata])
+
+        return output
+
+
+
+    def get_predicted_tags(self, sequence_logits: torch.Tensor) -> torch.Tensor:
+        """
+        Does a simple position-wise argmax over each token and returns the
+        predicted tag indices for every sequence in the batch.
+        """
+        all_predictions = sequence_logits
+        all_predictions = all_predictions.detach().cpu().numpy()
+        if all_predictions.ndim == 3:
+            predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
+        else:
+            predictions_list = [all_predictions]
+        all_tags = []
+        for predictions in predictions_list:
+            tags = np.argmax(predictions, axis=-1)
+            all_tags.append(tags)
+        return all_tags
+
+
+
+    @overrides
+    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """
+        Converts the tag ids to the actual tags.
+        ``output_dict["tags"]`` is a list of lists of tag_ids,
+        so we use an ugly nested list comprehension.
+        """
+        output_dict["tags"] = [
+            [self.vocab.get_token_from_index(tag, namespace=self.tag_label_namespace)
+             for tag in instance_tags]
+            for instance_tags in output_dict["tags"]
+        ]
+
+        argmax_indices = np.argmax(output_dict["domain_probs"].detach().cpu().numpy(), axis=-1)
+        output_dict["domain"] = [self.vocab.get_token_from_index(x, namespace="domain_labels")
+                                 for x in argmax_indices]
+
+        argmax_indices = np.argmax(output_dict["intent_probs"].detach().cpu().numpy(), axis=-1)
+        output_dict["intent"] = [self.vocab.get_token_from_index(x, namespace="intent_labels")
+                                 for x in argmax_indices]
+
+        output_dict["dialog_act"] = []
+        for i, domain in enumerate(output_dict["domain"]):
+            if "+" not in output_dict["intent"][i]:
+                tags = []
+                seq_len = len(output_dict["words"][i])
+                for span in bio_tags_to_spans(output_dict["tags"][i][:seq_len]):
+                    tags.append([span[0], " ".join(output_dict["words"][i][span[1][0]: span[1][1]+1])])
+                intent = output_dict["intent"][i] if len(tags) > 0 else "None"
+            else:
+                intent, value = output_dict["intent"][i].split("*", 1)
+                intent, slot = intent.split("+")
+                tags = [[slot, value]]
+            dialog_act = {domain+"-"+intent: tags} if domain != "None" and intent != "None" else {}
+            output_dict["dialog_act"].append(dialog_act)
+
+        return output_dict
+
+
+
+    @overrides
+    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
+        return self._dai_f1_metric.get_metric(reset=reset)
+
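Editor's note: forward() above sums three objectives into one scalar: the slot-tagging loss (CRF negative log-likelihood when crf_decoding is on, otherwise masked token-level cross-entropy) plus cross-entropy terms for the domain and intent heads. Below is a minimal PyTorch sketch of that composition, with illustrative shapes and the padding mask omitted; it is a sketch, not the library code.

# Editor's sketch of OneNet's joint objective; shapes are illustrative and
# masking is omitted (forward() above uses sequence_cross_entropy_with_logits).
import torch
import torch.nn.functional as F

batch, seq_len, num_tags, num_domains, num_intents = 2, 5, 9, 3, 4
tag_logits = torch.randn(batch, seq_len, num_tags, requires_grad=True)
domain_logits = torch.randn(batch, num_domains, requires_grad=True)
intent_logits = torch.randn(batch, num_intents, requires_grad=True)
tags = torch.randint(num_tags, (batch, seq_len))
domain = torch.randint(num_domains, (batch,))
intent = torch.randint(num_intents, (batch,))

# token-level tagging loss (the non-CRF branch of forward())
loss = F.cross_entropy(tag_logits.reshape(-1, num_tags), tags.reshape(-1))
# the sentence-level heads are simply added onto the same scalar
loss = loss + F.cross_entropy(domain_logits, domain)
loss = loss + F.cross_entropy(intent_logits, intent)
loss.backward()  # one backward pass updates all three heads jointly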
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/nlu.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/nlu.html
new file mode 100644
index 0000000..7ccf315
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/nlu.html
@@ -0,0 +1,270 @@
+convlab.modules.nlu.multiwoz.onenet.nlu — ConvLab 0.1 documentation

Source code for convlab.modules.nlu.multiwoz.onenet.nlu

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+"""
+
+import os
+from pprint import pprint
+
+from allennlp.common.checks import check_for_gpu
+from allennlp.data import DatasetReader
+from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
+from allennlp.models.archival import load_archive
+
+from convlab.lib.file_util import cached_path
+from convlab.modules.nlu.nlu import NLU
+from convlab.modules.nlu.multiwoz.onenet import dataset_reader, model 
+
+DEFAULT_CUDA_DEVICE=-1
+DEFAULT_DIRECTORY = "models"
+DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "onenet.tar.gz")
+
+
+class OneNetLU(NLU):
+    """Multilabel sequence tagging model."""
+
+    def __init__(self,
+                 archive_file=DEFAULT_ARCHIVE_FILE,
+                 cuda_device=DEFAULT_CUDA_DEVICE,
+                 model_file=None):
+        """ Constructor for NLU class. """
+        check_for_gpu(cuda_device)
+
+        if not os.path.isfile(archive_file):
+            if not model_file:
+                raise Exception("No model for OneNetLU is specified!")
+            archive_file = cached_path(model_file)
+
+        archive = load_archive(archive_file,
+                               cuda_device=cuda_device)
+        self.tokenizer = SpacyWordSplitter(language="en_core_web_sm")
+        dataset_reader_params = archive.config["dataset_reader"]
+        self.dataset_reader = DatasetReader.from_params(dataset_reader_params)
+        self.model = archive.model
+        self.model.eval()
+
+    def parse(self, utterance, context=None):
+        """
+        Predict the dialog act of a natural language utterance.
+        Args:
+            utterance (str): A natural language utterance.
+        Returns:
+            output (dict): The dialog act of utterance.
+        """
+        if len(utterance) == 0:
+            return {}
+
+        tokens = self.tokenizer.split_words(utterance)
+        instance = self.dataset_reader.text_to_instance(tokens)
+        outputs = self.model.forward_on_instance(instance)
+
+        return outputs["dialog_act"]
+
+
+if __name__ == "__main__":
+    nlu = OneNetLU()
+    test_utterances = [
+        "What type of accommodations are they. No , i just need their address . Can you tell me if the hotel has internet available ?",
+        "What type of accommodations are they.",
+        "No , i just need their address .",
+        "Can you tell me if the hotel has internet available ?",
+        "you're welcome! enjoy your visit! goodbye.",
+        "yes. it should be moderately priced.",
+        "i want to book a table for 6 at 18:45 on thursday",
+        "i will be departing out of stevenage.",
+        "What is the Name of attraction ?",
+        "Can I get the name of restaurant?",
+        "do you have a specific area you want to stay in?"
+    ]
+    for utt in test_utterances:
+        print(utt)
+        pprint(nlu.parse(utt))
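Editor's note: per OneNet.decode() above, parse() returns a dict keyed by "Domain-Intent" strings, each mapping to a list of [slot, value] pairs recovered from the BIO tag spans. A hedged, illustrative sketch follows; the labels shown are hypothetical, not captured model output.

# Editor's sketch; assumes a trained archive is available at the default path.
from convlab.modules.nlu.multiwoz.onenet.nlu import OneNetLU

nlu = OneNetLU()
print(nlu.parse("Can you tell me if the hotel has internet available ?"))
# hypothetical output shape: {"Hotel-Request": [["Internet", "?"]]}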
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/train.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/train.html
new file mode 100644
index 0000000..db34796
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/onenet/train.html
@@ -0,0 +1,385 @@
+convlab.modules.nlu.multiwoz.onenet.train — ConvLab 0.1 documentation

Source code for convlab.modules.nlu.multiwoz.onenet.train

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+The ``train`` subcommand can be used to train a model.
+It requires a configuration file and a directory in
+which to write the results.
+"""
+
+import argparse
+import logging
+import os
+
+from allennlp.common import Params
+from allennlp.common.checks import check_for_gpu
+from allennlp.common.util import prepare_environment, prepare_global_logging, cleanup_global_logging, dump_metrics
+from allennlp.models.archival import archive_model, CONFIG_NAME
+from allennlp.models.model import Model, _DEFAULT_WEIGHTS
+from allennlp.training.trainer import Trainer, TrainerPieces
+from allennlp.training.trainer_base import TrainerBase
+from allennlp.training.util import create_serialization_dir, evaluate
+
+from convlab.modules.nlu.multiwoz.onenet import dataset_reader, model 
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+argparser = argparse.ArgumentParser(description="Train a model.")
+argparser.add_argument('param_path',
+                        type=str,
+                        help='path to parameter file describing the model to be trained')
+argparser.add_argument('-s', '--serialization-dir',
+                        required=True,
+                        type=str,
+                        help='directory in which to save the model and its logs')
+argparser.add_argument('-r', '--recover',
+                        action='store_true',
+                        default=False,
+                        help='recover training from the state in serialization_dir')
+argparser.add_argument('-f', '--force',
+                        action='store_true',
+                        required=False,
+                        help='overwrite the output directory if it exists')
+argparser.add_argument('-o', '--overrides',
+                        type=str,
+                        default="",
+                        help='a JSON structure used to override the experiment configuration')
+argparser.add_argument('--file-friendly-logging',
+                        action='store_true',
+                        default=False,
+                        help='outputs tqdm status on separate lines and slows tqdm refresh rate')
+
+
+
+
+def train_model_from_args(args: argparse.Namespace):
+    """
+    Just converts from an ``argparse.Namespace`` object to string paths.
+    """
+    train_model_from_file(args.param_path,
+                          args.serialization_dir,
+                          args.overrides,
+                          args.file_friendly_logging,
+                          args.recover,
+                          args.force)
+ + +
+def train_model_from_file(parameter_filename: str,
+                          serialization_dir: str,
+                          overrides: str = "",
+                          file_friendly_logging: bool = False,
+                          recover: bool = False,
+                          force: bool = False) -> Model:
+    """
+    A wrapper around :func:`train_model` which loads the params from a file.
+
+    Parameters
+    ----------
+    parameter_filename : ``str``
+        A json parameter file specifying an AllenNLP experiment.
+    serialization_dir : ``str``
+        The directory in which to save results and logs. We just pass this along to
+        :func:`train_model`.
+    overrides : ``str``
+        A JSON string that we will use to override values in the input parameter file.
+    file_friendly_logging : ``bool``, optional (default=False)
+        If ``True``, we make our output more friendly to saved model files. We just pass this
+        along to :func:`train_model`.
+    recover : ``bool``, optional (default=False)
+        If ``True``, we will try to recover a training run from an existing serialization
+        directory. This is only intended for use when something actually crashed during the middle
+        of a run. For continuing training a model on new data, see the ``fine-tune`` command.
+    force : ``bool``, optional (default=False)
+        If ``True``, we will overwrite the serialization directory if it already exists.
+    """
+    # Load the experiment config from a file and pass it to ``train_model``.
+    params = Params.from_file(parameter_filename, overrides)
+    return train_model(params, serialization_dir, file_friendly_logging, recover, force)
+ + +
+def train_model(params: Params,
+                serialization_dir: str,
+                file_friendly_logging: bool = False,
+                recover: bool = False,
+                force: bool = False) -> Model:
+    """
+    Trains the model specified in the given :class:`Params` object, using the data and training
+    parameters also specified in that object, and saves the results in ``serialization_dir``.
+
+    Parameters
+    ----------
+    params : ``Params``
+        A parameter object specifying an AllenNLP Experiment.
+    serialization_dir : ``str``
+        The directory in which to save results and logs.
+    file_friendly_logging : ``bool``, optional (default=False)
+        If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
+        down tqdm's output to only once every 10 seconds.
+    recover : ``bool``, optional (default=False)
+        If ``True``, we will try to recover a training run from an existing serialization
+        directory. This is only intended for use when something actually crashed during the middle
+        of a run. For continuing training a model on new data, see the ``fine-tune`` command.
+    force : ``bool``, optional (default=False)
+        If ``True``, we will overwrite the serialization directory if it already exists.
+
+    Returns
+    -------
+    best_model: ``Model``
+        The model with the best epoch weights.
+    """
+    prepare_environment(params)
+    create_serialization_dir(params, serialization_dir, recover, force)
+    stdout_handler = prepare_global_logging(serialization_dir, file_friendly_logging)
+
+    cuda_device = params.params.get('trainer').get('cuda_device', -1)
+    check_for_gpu(cuda_device)
+
+    params.to_file(os.path.join(serialization_dir, CONFIG_NAME))
+
+    evaluate_on_test = params.pop_bool("evaluate_on_test", False)
+
+    trainer_type = params.get("trainer", {}).get("type", "default")
+
+    if trainer_type == "default":
+        # Special logic to instantiate backward-compatible trainer.
+        pieces = TrainerPieces.from_params(params, serialization_dir, recover)  # pylint: disable=no-member
+        trainer = Trainer.from_params(
+            model=pieces.model,
+            serialization_dir=serialization_dir,
+            iterator=pieces.iterator,
+            train_data=pieces.train_dataset,
+            validation_data=pieces.validation_dataset,
+            params=pieces.params,
+            validation_iterator=pieces.validation_iterator)
+        evaluation_iterator = pieces.validation_iterator or pieces.iterator
+        evaluation_dataset = pieces.test_dataset
+
+    else:
+        trainer = TrainerBase.from_params(params, serialization_dir, recover)
+        # TODO(joelgrus): handle evaluation in the general case
+        evaluation_iterator = evaluation_dataset = None
+
+    params.assert_empty('base train command')
+
+    try:
+        metrics = trainer.train()
+    except KeyboardInterrupt:
+        # if we have completed an epoch, try to create a model archive.
+        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
+            logging.info("Training interrupted by the user. Attempting to create "
+                         "a model archive using the current best epoch weights.")
+            archive_model(serialization_dir, files_to_archive=params.files_to_archive)
+        raise
+
+    # Evaluate
+    if evaluation_dataset and evaluate_on_test:
+        logger.info("The model will be evaluated using the best epoch weights.")
+        test_metrics = evaluate(trainer.model, evaluation_dataset, evaluation_iterator,
+                                cuda_device=trainer._cuda_devices[0],  # pylint: disable=protected-access
+                                # TODO(brendanr): Pass in an arg following Joel's trainer refactor.
+                                batch_weight_key="")
+
+        for key, value in test_metrics.items():
+            metrics["test_" + key] = value
+
+    elif evaluation_dataset:
+        logger.info("To evaluate on the test set after training, pass the "
+                    "'evaluate_on_test' flag, or use the 'allennlp evaluate' command.")
+
+    cleanup_global_logging(stdout_handler)
+
+    # Now tar up results
+    archive_model(serialization_dir, files_to_archive=params.files_to_archive)
+    dump_metrics(os.path.join(serialization_dir, "metrics.json"), metrics, log=True)
+
+    # We count on the trainer to have the model with best weights
+    return trainer.model
+
+if __name__ == "__main__":
+    args = argparser.parse_args()
+    train_model_from_args(args)
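Editor's note: a hedged example of launching training through this script's interface; the config path and serialization directory below are hypothetical placeholders, not files shipped with ConvLab.

# Editor's sketch: the CLI equivalent would be
#   python -m convlab.modules.nlu.multiwoz.onenet.train \
#       configs/onenet_multiwoz.jsonnet -s output/onenet
from convlab.modules.nlu.multiwoz.onenet.train import train_model_from_file

best_model = train_model_from_file(
    parameter_filename="configs/onenet_multiwoz.jsonnet",  # hypothetical experiment config
    serialization_dir="output/onenet",                     # where checkpoints/metrics land
    force=True)                                            # overwrite an existing run dir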
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/Classifier.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/Classifier.html
new file mode 100644
index 0000000..53bbac6
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/Classifier.html
@@ -0,0 +1,736 @@
+convlab.modules.nlu.multiwoz.svm.Classifier — ConvLab 0.1 documentation

Source code for convlab.modules.nlu.multiwoz.svm.Classifier

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import json
+import multiprocessing as mp
+import os
+import pickle
+import time
+from collections import defaultdict
+
+import numpy
+from scipy.sparse import lil_matrix
+
+from convlab.modules.nlu.multiwoz.svm import sutils, Tuples
+from convlab.modules.nlu.multiwoz.svm.Features import cnet as cnet_extractor
+
+names_to_classes = {}
+
+
+def trainSVMwrapper(X, y):
+    # svm is imported near the bottom of this module; that works because this
+    # function is only called after the module has been fully loaded.
+    model = svm.SVC(kernel='linear', C=1)
+    model.probability = True
+    # model.class_weight = 'auto'
+    model.fit(X, y)
+    return model
+ +
+class classifier(object):
+    def __init__(self, config):
+        # classifier type
+        self.type = "svm"
+        if config.has_option("classifier", "type"):
+            self.type = config.get("classifier", "type")
+
+        # min_examples
+        self.min_examples = 10
+        if config.has_option("classifier", "min_examples"):
+            self.min_examples = int(config.get("classifier", "min_examples"))
+
+        # features
+        self.features = ["cnet"]
+        if config.has_option("classifier", "features"):
+            self.features = json.loads(config.get("classifier", "features"))
+        self.feature_extractors = []
+        for feature in self.features:
+            self.feature_extractors.append(
+                sutils.import_class("convlab.modules.nlu.multiwoz.svm.Features." + feature)(config)
+            )
+        print(self.feature_extractors)
+        self.tuples = Tuples.tuples(config)
+        self.config = config
+        self.cnet_extractor = cnet_extractor(config)
+
+        # store data:
+        self.X = {}
+        self.y = {}
+        self.baseXs = []
+        self.baseX_pointers = {}
+        self.fnames = {}
+
+    # @profile
+    def extractFeatures(self, dw, log_input_key="batch"):
+        # given a dataset walker,
+        # adds examples to self.X and self.y
+        total_calls = len(dw.session_list)
+        print(total_calls)
+        self.keys = set([])
+        for call_num, call in enumerate(dw):
+            print('[%d/%d]' % (call_num, total_calls))
+            for log_turn, label_turn in call:
+                if label_turn != None:
+                    uacts = label_turn['semantics']['json']
+                    these_tuples = self.tuples.uactsToTuples(uacts)
+                    # check there aren't any tuples we were not expecting:
+                    for this_tuple in these_tuples:
+                        if this_tuple not in self.tuples.all_tuples:
+                            print("Warning: unexpected tuple", this_tuple)
+                    # convert tuples to specific tuples:
+                    these_tuples = [Tuples.generic_to_specific(tup) for tup in these_tuples]
+
+                # which tuples would be considered (active) for this turn?
+                active_tuples = self.tuples.activeTuples(log_turn)
+
+                # calculate base features that are independent of the tuple
+                baseX = defaultdict(float)
+                for feature_extractor in self.feature_extractors:
+                    feature_name = feature_extractor.__class__.__name__
+                    new_feats = feature_extractor.calculate(log_turn, log_input_key=log_input_key)
+                    for key in new_feats:
+                        baseX[(feature_name, key)] += new_feats[key]
+                        self.keys.add((feature_name, key))
+                self.baseXs.append(baseX)
+
+                for this_tuple in active_tuples:
+                    if label_turn != None:
+                        y = (Tuples.generic_to_specific(this_tuple) in these_tuples)
+
+                    X = defaultdict(float)
+                    for feature_extractor in self.feature_extractors:
+                        feature_name = feature_extractor.__class__.__name__
+                        new_feats = feature_extractor.tuple_calculate(this_tuple, log_turn, log_input_key=log_input_key)
+                        for key in new_feats:
+                            X[(feature_name, key)] += new_feats[key]
+                            self.keys.add((feature_name, key))
+
+                    if this_tuple not in self.X:
+                        self.X[this_tuple] = []
+                    if this_tuple not in self.y:
+                        self.y[this_tuple] = []
+                    if this_tuple not in self.baseX_pointers:
+                        self.baseX_pointers[this_tuple] = []
+
+                    self.X[this_tuple].append(X)
+                    if label_turn != None:
+                        self.y[this_tuple].append(y)
+
+                    self.baseX_pointers[this_tuple].append(len(self.baseXs) - 1)
+
+        # self.fnames[this_tuple].append(log_turn["input"]["audio-file"])
+
+
+    def extractFeatures2(self, sentinfo, log_input_key="batch"):
+        # given a single sentence's info dict (rather than a dataset walker),
+        # adds examples to self.X; there are no labels at decode time
+        total_calls = 1
+        self.keys = set([])
+
+        # calculate base features that are independent of the tuple
+        baseX = defaultdict(float)
+        for feature_extractor in self.feature_extractors:
+            feature_name = feature_extractor.__class__.__name__
+            new_feats = feature_extractor.calculate_sent(sentinfo, log_input_key=log_input_key)
+            for key in new_feats:
+                baseX[(feature_name, key)] += new_feats[key]
+                self.keys.add((feature_name, key))
+        self.baseXs.append(baseX)
+
+        for this_tuple in self.classifiers:
+            X = defaultdict(float)
+            for feature_extractor in self.feature_extractors:
+                feature_name = feature_extractor.__class__.__name__
+                new_feats = feature_extractor.tuple_calculate(this_tuple, sentinfo, log_input_key=log_input_key)
+                for key in new_feats:
+                    X[(feature_name, key)] += new_feats[key]
+                    self.keys.add((feature_name, key))
+
+            if this_tuple not in self.X:
+                self.X[this_tuple] = []
+            if this_tuple not in self.y:
+                self.y[this_tuple] = []
+            if this_tuple not in self.baseX_pointers:
+                self.baseX_pointers[this_tuple] = []
+
+            self.X[this_tuple].append(X)
+
+            self.baseX_pointers[this_tuple].append(len(self.baseXs) - 1)
+ + +
+    def createDictionary(self):
+        self.dictionary = {}
+        for i, key in enumerate(self.keys):
+            self.dictionary[key] = i
+ + +
+    def cacheFeature(self, dw, config=None):
+        if config == None:
+            config = self.config
+        log_input_key = "batch"
+        if config.has_option("train", "log_input_key"):
+            log_input_key = config.get("train", "log_input_key")
+        print("extracting features from turns")
+        self.extractFeatures(dw, log_input_key=log_input_key)
+        print("finished extracting features")
+        print("creating feature dictionary")
+        self.createDictionary()
+        print("finished creating dictionary (of size", len(self.dictionary), ")")
+ +
+    def train(self, dw, config=None):
+        self.classifiers = {}
+        total_num = len(self.tuples.all_tuples)
+        cur_num = 0
+        print(self.tuples.all_tuples)
+        print(self.X.keys())
+
+        pool = mp.Pool(processes=20)
+        res = []
+
+        for this_tuple in self.tuples.all_tuples:
+            cur_num += 1
+            print("%d/%d" % (cur_num, total_num))
+            print("training", this_tuple)
+            if this_tuple not in self.X:
+                print("Warning: no examples of", this_tuple)
+                self.classifiers[this_tuple] = None
+                continue
+            baseXs = [self.baseXs[index] for index in self.baseX_pointers[this_tuple]]
+            y = list(map(int, self.y[this_tuple]))
+            if sum(y) < self.min_examples:
+                print("Warning: not enough examples (%d) of" % sum(y), this_tuple)
+                self.classifiers[this_tuple] = None
+                continue
+            if len(set(y)) < 2:
+                print("Warning: only one class of", this_tuple)
+                self.classifiers[this_tuple] = None
+                continue
+            X = toSparse(baseXs, self.X[this_tuple], self.dictionary)
+
+            # pick the right classifier class
+            self.classifiers[this_tuple] = names_to_classes[self.type](self.config)
+            # self.classifiers[this_tuple].train(X,y)
+
+            result = pool.apply_async(trainSVMwrapper, args=(X, y))
+            res.append((result, this_tuple))
+
+            del self.X[this_tuple]
+            del self.y[this_tuple]
+
+        pool.close()
+        pool.join()
+        for result, this_tuple in res:
+            self.classifiers[this_tuple].model = result.get()
+
+        no_models = [this_tuple for this_tuple in self.classifiers if self.classifiers[this_tuple] == None]
+
+        if no_models:
+            print("Not able to learn about: %d/%d" % (len(no_models), total_num))
+        # print(len(no_models))
+        # print(", ".join(map(str, no_models)))
+
+    def decode(self):
+        # run the classifiers on self.X, return results
+        results = {}
+        for this_tuple in self.classifiers:
+            if this_tuple not in self.X:
+                print("warning: Did not collect features for ", this_tuple)
+                continue
+            n = len(self.X[this_tuple])
+            if self.classifiers[this_tuple] == None:
+                results[this_tuple] = numpy.zeros((n,))
+                continue
+            baseXs = [self.baseXs[index] for index in self.baseX_pointers[this_tuple]]
+            X = toSparse(baseXs, self.X[this_tuple], self.dictionary)
+            results[this_tuple] = self.classifiers[this_tuple].predict(X)
+        return results
+ + +
+    def decodeToFile(self, dw, output_fname, config=None):
+        if config == None:
+            config = self.config
+        t0 = time.time()
+        results = {
+            "wall-time": 0.0,  # add later
+            "dataset": dw.datasets,
+            "sessions": []
+        }
+        log_input_key = "batch"
+        if config.has_option("decode", "log_input_key"):
+            log_input_key = config.get("decode", "log_input_key")
+
+        total_calls = len(dw.session_list)
+        for call_num, call in enumerate(dw):
+            print('[%d/%d]' % (call_num, total_calls))
+            session = {"session-id": call.log["session-id"], "turns": []}
+            for log_turn, _ in call:
+                slu_hyps = self.decode_sent(log_turn['input']['live'], config)
+                session["turns"].append({
+                    "utterance": log_turn['input']['live']['asr-hyps'][0]['asr-hyp'],
+                    "predict": slu_hyps[0]['slu-hyp']
+                })
+            results["sessions"].append(session)
+
+        results["wall-time"] = time.time() - t0
+        # open in text mode: json.dump writes str, not bytes
+        output_file = open(output_fname, "w")
+        json.dump(results, output_file, indent=4)
+        output_file.close()
+ + +
+    def decode_sent(self, sentinfo, output_fname, config=None):
+        if config == None:
+            config = self.config
+        t0 = time.time()
+        self.X = {}
+        self.y = {}
+        self.baseXs = []
+        self.baseX_pointers = {}
+        self.fnames = {}
+        log_input_key = "batch"
+        if config.has_option("decode", "log_input_key"):
+            log_input_key = config.get("decode", "log_input_key")
+
+        self.extractFeatures2(sentinfo, log_input_key=log_input_key)
+        decode_results = self.decode()
+        counter = defaultdict(int)
+
+        active_tuples = self.tuples.activeTuples_sent(sentinfo)
+        tuple_distribution = {}
+        for this_tuple in active_tuples:
+            index = counter[this_tuple]
+            assert len(decode_results[this_tuple]) == 1
+            if len(decode_results[this_tuple]) - 1 < index:
+                p = 0
+            else:
+                p = decode_results[this_tuple][index]
+            tuple_distribution[Tuples.generic_to_specific(this_tuple)] = p
+            # check we are decoding the right utterance
+            counter[this_tuple] += 1
+        slu_hyps = self.tuples.distributionToNbest(tuple_distribution)
+
+        return slu_hyps
+ + + + +
+    def save(self, save_fname):
+        classifier_params = {}
+        for this_tuple in self.classifiers:
+            if self.classifiers[this_tuple] == None:
+                classifier_params[this_tuple] = None
+            else:
+                print('saving: ', this_tuple)
+                classifier_params[this_tuple] = self.classifiers[this_tuple].params()
+
+        obj = {
+            "classifier_params": classifier_params,
+            "dictionary": self.dictionary
+        }
+        save_file = open(save_fname, "wb")
+        pickle.dump(obj, save_file)
+        save_file.close()
+ + +
+    def load(self, fname):
+        rootpath = os.path.dirname(os.path.abspath(__file__))
+        fname = os.path.join(rootpath, fname)
+        print("loading saved Classifier")
+        print(fname)
+        obj = pickle.load(open(fname, 'rb'))
+        print("loaded.")
+        classifier_params = obj["classifier_params"]
+        self.classifiers = {}
+        for this_tuple in classifier_params:
+            if classifier_params[this_tuple] == None:
+                self.classifiers[this_tuple] = None
+            else:
+                self.classifiers[this_tuple] = names_to_classes[self.type](self.config)
+                self.classifiers[this_tuple].load(classifier_params[this_tuple])
+
+        self.dictionary = obj["dictionary"]
+ +
+    def export(self, models_fname, dictionary_fname, config_fname):
+        print("exporting Classifier for Caesar to read")
+        print("models to be saved in", models_fname)
+        print("dictionary to be saved in", dictionary_fname)
+        print("config to be saved in", config_fname)
+
+        if self.type != "svm":
+            print("Only know how to export SVMs")
+            return
+        lines = []
+        for this_tuple in self.classifiers:
+            if self.classifiers[this_tuple] != None:
+                t = this_tuple
+                if Tuples.is_generic(this_tuple[-1]):
+                    t = this_tuple[:-1] + ("<generic_value>",)
+                lines += ['(' + ','.join(t) + ')']
+                lines += sutils.svm_to_libsvm(self.classifiers[this_tuple].model)
+                lines += [".", ""]
+        # write as text, not bytes, since lines are strings
+        models_savefile = open(models_fname, "w")
+        for line in lines:
+            models_savefile.write(line + "\n")
+        models_savefile.close()
+
+        # save dictionary, sorted by index
+        # (dict.items() is not a sortable list in Python 3)
+        dictionary_items = sorted(self.dictionary.items(), key=lambda x: x[1])
+        assert [x[1] for x in dictionary_items] == list(range(len(self.dictionary)))
+        keys = [list(x[0]) for x in dictionary_items]
+
+        json.dump(keys, open(dictionary_fname, "w"))
+
+        # save config
+        config_savefile = open(config_fname, "w")
+        config_savefile.write("# Automatically generated by CNetTrain scripts\n")
+        options = {
+            "FEATURES": json.dumps(self.features),
+            "MAX_ACTIVE_TUPLES": str(self.tuples.max_active),
+            "TAIL_CUTOFF": str(self.tuples.tail_cutoff),
+            "MODELS": os.path.join(os.getcwd(), models_fname),
+            "DICTIONARY": os.path.join(os.getcwd(), dictionary_fname),
+        }
+        if "cnet" in self.features:
+            index = self.features.index("cnet")
+            cnf = self.feature_extractors[index]
+            options["MAX_NGRAM_LENGTH"] = str(cnf.max_length)
+            options["MAX_NGRAMS"] = str(cnf.max_ngrams)
+        for key in options:
+            this_line = "CNET : %s" % key
+            this_line = this_line.ljust(30)
+            this_line += "= " + options[key]
+            config_savefile.write("\t" + this_line + "\n")
+        config_savefile.close()
+        print("exported Classifier.")
+ + +
+def toSparse(baseX, X, dictionary):
+    # convert baseX & X (a list of dictionaries) to a sparse matrix,
+    # using dictionary to map feature keys to column indices
+    out = lil_matrix((len(X), len(dictionary)))
+    for i, (basex, x) in enumerate(zip(baseX, X)):
+        for key in basex:
+            if key not in dictionary:
+                continue
+            out[i, dictionary[key]] = basex[key]
+        for key in x:
+            if key not in dictionary:
+                continue
+            out[i, dictionary[key]] = x[key]
+
+    out = out.tocsr()
+    return out
+
+
+# classifiers define :
+#   train(X, y)
+#   predict(X)
+#   params()
+#   load(params)
+# X is a sparse matrix, y is a vector of class labels (ints)
+from sklearn import svm
+class SVM():
+    def __init__(self, config):
+        self.C = 1
+
+    def pickC(self, X, y):
+        Cs = [1, 0.1, 5, 10, 50]  # 1 goes first as it should be preferred
+        scores = []
+        n = X.shape[0]
+        dev_index = max([int(n*0.8), 1+y.index(1)])
+        max_score = 0.0
+        self.C = Cs[0]
+        print("Warning, not picking C from validation")
+        return
+        # NOTE: everything below is dead code because of the early return
+        # above; it also relies on svm.sparse.SVC, which no longer exists in
+        # modern scikit-learn.
+        for i, C in enumerate(Cs):
+            this_model = svm.sparse.SVC(C=C, kernel='linear')
+            this_model.probability = False
+            this_model.class_weight = 'auto'
+
+            this_model.fit(X[:dev_index, :], y[:dev_index])
+            pred = this_model.predict(X)
+            train_correct = 0.0
+            dev_correct = 0.0
+            for j, y_j in enumerate(y):
+                if j < dev_index:
+                    train_correct += int(y_j == pred[j])
+                else:
+                    dev_correct += int(y_j == pred[j])
+            train_acc = train_correct/dev_index
+            dev_acc = dev_correct/(n-dev_index)
+            score = (0.1*train_acc + 0.9*dev_acc)
+            print("\tfor C=%.2f;\n\t\t train_acc=%.4f, dev_acc=%.4f, score=%.4f" % (C, train_acc, dev_acc, score))
+            if score > max_score:
+                max_score = score
+                self.C = C
+            if score == 1.0:
+                break
+        print("Selected C=%.2f" % self.C)
+ + +
+    def train(self, X, y):
+        self.pickC(X, y)
+        # model = svm.sparse.SVC(kernel='linear', C=self.C)
+        model = svm.SVC(kernel='linear', C=self.C)
+        model.probability = True
+        # model.class_weight = 'auto'
+        model.fit(X, y)
+        self.model = model
+ +
+    def predict(self, X):
+        y = self.model.predict_proba(X)
+        return y[:, 1]
+ +
+    def params(self):
+        return self.model
+ +
+    def load(self, params):
+        self.model = params
+
+names_to_classes["svm"] = SVM
+
+from sklearn.linear_model import SGDClassifier
+
+class SGD():
+    def __init__(self, config):
+        pass
+
+    def train(self, X, y):
+        model = SGDClassifier(loss="log", penalty="l2")
+        model.probability = True
+        model.fit(X, y)
+        self.model = model
+ +
+    def predict(self, X):
+        y = self.model.predict_proba(X)
+        return y[:, 1]
+ +
+    def params(self):
+        return self.model
+ +
+    def load(self, params):
+        self.model = params
+
+names_to_classes["sgd"] = SGD
+
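Editor's note: toSparse() above vectorizes the per-turn feature dicts against the trained feature dictionary. A small self-contained sketch; the feature keys are made-up examples of the ("extractor_name", key) pairs produced by extractFeatures().

# Editor's sketch of toSparse(); the keys below are hypothetical features.
from convlab.modules.nlu.multiwoz.svm.Classifier import toSparse

dictionary = {("nbest", "hello"): 0, ("nbest", "hello there"): 1, ("lastSys", ("request",)): 2}
baseX = [{("lastSys", ("request",)): 1.0}]   # tuple-independent features, one dict per turn
X = [{("nbest", "hello"): 0.5}]              # tuple-specific features, aligned with baseX

m = toSparse(baseX, X, dictionary)           # 1 x 3 scipy CSR matrix
print(m.toarray())                           # [[0.5 0.  1. ]]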
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/Features.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/Features.html
new file mode 100644
index 0000000..cd39678
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/Features.html
@@ -0,0 +1,561 @@
+convlab.modules.nlu.multiwoz.svm.Features — ConvLab 0.1 documentation

Source code for convlab.modules.nlu.multiwoz.svm.Features

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+
+import itertools
+import math
+import json
+from collections import defaultdict
+
+from convlab.modules.nlu.multiwoz.svm import Tuples
+
+
+
[docs]class lastSys(object): + def __init__(self, config): + pass + +
[docs] def calculate(self, log_turn,log_input_key="batch"): + acts = log_turn["output"]["dialog-acts"] + out = defaultdict(float) + for act in acts: + act_type =act["act"] + out[(act_type,)] += 1 + for slot,value in act["slots"]: + if act_type == "request" : + out[("request", value)] += 1 + else : + out[(act_type,slot)] += 1 + out[(act_type, slot, value)]+=1 + out[(slot,value)]+=1 + return out
+ +
[docs] def tuple_calculate(self, this_tuple, log_turn,log_input_key="batch"): + return {}
+ + + +
[docs]class valueIdentifying(object): + def __init__(self, config): + pass + +
[docs] def calculate(self, log_turn,log_input_key="batch"): + return {}
+ +
[docs] def tuple_calculate(self, this_tuple, log_turn,log_input_key="batch"): + if Tuples.is_generic(this_tuple[-1]) : + return {"<generic_value="+this_tuple[-1].value+">":1} + else : + return {}
+ +
[docs]class nbest(object): + def __init__(self, config): + self.max_length = 3 + if config.has_option("classifier", "max_ngram_length") : + self.max_length = int(config.get("classifier", "max_ngram_length")) + self.skip_ngrams = False + if config.has_option("classifier","skip_ngrams") : + self.skip_ngrams = config.get("classifier","skip_ngrams")=="True" + self.skip_ngram_decay = 0.9 + if config.has_option("classifier","skip_ngram_decay") : + self.skip_ngram_decay = float(config.get("classifier","skip_ngram_decay")) + self.max_ngrams = 200 + if config.has_option("classifier", "max_ngrams") : + self.max_ngrams = int(config.get("classifier", "max_ngrams")) + +
+    def calculate(self, log_turn, log_input_key="batch"):
+
+        asr_hyps = [(hyp["score"], hyp["asr-hyp"]) for hyp in log_turn["input"][log_input_key]["asr-hyps"]]
+        asr_hyps = [(score, hyp) for score, hyp in asr_hyps if score > -100]
+        # do exp of scores and normalise
+        if (len(asr_hyps) == 0):
+            return {}
+
+        min_score = min([score for score, _hyp in asr_hyps])
+
+        asr_hyps = [(math.exp(score+min_score), hyp) for score, hyp in asr_hyps]
+        total_p = sum([score for score, _hyp in asr_hyps])
+
+        if total_p == 0:
+            print(asr_hyps)
+        asr_hyps = [(score/total_p, hyp) for score, hyp in asr_hyps]
+
+        ngrams = defaultdict(float)
+
+        for p, asr_hyp in asr_hyps:
+            these_ngrams = get_ngrams(asr_hyp.lower(), self.max_length, skip_ngrams=self.skip_ngrams)
+            for ngram, skips in these_ngrams:
+                skip_decay = 1.0
+                for skip in skips:
+                    # accumulate the decay for each skipped gap (the original
+                    # reassigned the loop variable, leaving skip_decay unused)
+                    skip_decay *= (self.skip_ngram_decay**(skip-1))
+                ngrams[ngram] += p * skip_decay
+
+        self.final_ngrams = ngrams.items()
+        self.final_ngrams = sorted(self.final_ngrams, key=lambda x: -x[1])
+        self.final_ngrams = self.final_ngrams[:self.max_ngrams]
+        return ngrams
+ +
+    def calculate_sent(self, log_turn, log_input_key="batch"):
+
+        asr_hyps = [(hyp["score"], hyp["asr-hyp"]) for hyp in log_turn["asr-hyps"]]
+        asr_hyps = [(score, hyp) for score, hyp in asr_hyps if score > -100]
+        # do exp of scores and normalise
+        if (len(asr_hyps) == 0):
+            return {}
+
+        min_score = min([score for score, _hyp in asr_hyps])
+
+        asr_hyps = [(math.exp(score+min_score), hyp) for score, hyp in asr_hyps]
+        total_p = sum([score for score, _hyp in asr_hyps])
+
+        if total_p == 0:
+            print(asr_hyps)
+        asr_hyps = [(score/total_p, hyp) for score, hyp in asr_hyps]
+
+        ngrams = defaultdict(float)
+
+        for p, asr_hyp in asr_hyps:
+            these_ngrams = get_ngrams(asr_hyp.lower(), self.max_length, skip_ngrams=self.skip_ngrams)
+            for ngram, skips in these_ngrams:
+                skip_decay = 1.0
+                for skip in skips:
+                    # accumulate the decay for each skipped gap (same fix as
+                    # in calculate() above)
+                    skip_decay *= (self.skip_ngram_decay**(skip-1))
+                ngrams[ngram] += p * skip_decay
+
+        self.final_ngrams = ngrams.items()
+        self.final_ngrams = sorted(self.final_ngrams, key=lambda x: -x[1])
+        self.final_ngrams = self.final_ngrams[:self.max_ngrams]
+        return ngrams
+ +
[docs] def tuple_calculate(self, this_tuple, log_turn,log_input_key="batch"): + final_ngrams = self.final_ngrams + # do we need to add generic ngrams? + new_ngrams = [] + + if Tuples.is_generic(this_tuple[-1]) : + gvalue = this_tuple[-1] + for ngram, score in final_ngrams: + if gvalue.value is not None: + if gvalue.value.lower() in ngram : + new_ngram = ngram.replace(gvalue.value.lower(), "<generic_value>") + new_ngrams.append((new_ngram,score)) + + return dict(new_ngrams)
+ + +
[docs]def get_ngrams(sentence, max_length, skip_ngrams=False, add_tags = True): + # return ngrams of length up to max_length as found in sentence. + out = [] + words = sentence.split() + if add_tags : + words = ["<s>"]+words+["</s>"] + if not skip_ngrams : + for i in range(len(words)): + for n in range(1,min(max_length+1, len(words)-i+1)): + this_ngram = " ".join(words[i:i+n]) + out.append((this_ngram,[])) + else : + for n in range(1, max_length+1): + subsets = set(itertools.combinations(range(len(words)), n)) + for subset in subsets: + subset = sorted(subset) + dists = [(subset[i]-subset[i-1]) for i in range(1, len(subset))] + out.append((" ".join([words[j] for j in subset]), dists)) + + + return out
+ + + + +
[docs]class nbestLengths(object) : + def __init__(self, config): + pass +
[docs] def calculate(self, log_turn,log_input_key="batch"): + out = {} + hyps = [hyp["asr-hyp"] for hyp in log_turn["input"][log_input_key]["asr-hyps"]] + for i, hyp in enumerate(hyps): + out[i] = len(hyp.split()) + return out
+ +
[docs] def tuple_calculate(self, this_tuple, log_turn ,log_input_key="batch"): + return {}
+ +
[docs]class nbestScores(object) : + def __init__(self, config): + pass +
[docs] def calculate(self, log_turn,log_input_key="batch"): + out = {} + scores = [hyp["score"] for hyp in log_turn["input"][log_input_key]["asr-hyps"]] + for i, score in enumerate(scores): + out[i] = score + return out
+ +
[docs] def tuple_calculate(self, this_tuple, log_turn,log_input_key="batch" ): + return {}
+ + +
[docs]class cnet(object): + def __init__(self, config): + self.slots_enumerated = json.loads(config.get("grammar", "slots_enumerated")) + self.max_length = 3 + if config.has_option("classifier", "max_ngram_length") : + self.max_length = int(config.get("classifier", "max_ngram_length")) + self.max_ngrams = 200 + if config.has_option("classifier", "max_ngrams") : + self.max_ngrams = int(config.get("classifier", "max_ngrams")) + self.final_ngrams = None + self.last_parse = None + +
[docs] def calculate(self, log_turn,log_input_key="batch"): + if self.last_parse == log_turn["input"]["audio-file"] : + return dict([(ng.string_repn(), ng.score()) for ng in self.final_ngrams]) + cnet = log_turn["input"][log_input_key]["cnet"] + self.final_ngrams = get_cnngrams(cnet,self.max_ngrams, self.max_length) + self.last_parse = log_turn["input"]["audio-file"] + return dict([(ng.string_repn(), ng.score()) for ng in self.final_ngrams])
+ + +
[docs] def tuple_calculate(self, this_tuple, log_turn,log_input_key="batch"): + final_ngrams = self.final_ngrams + # do we need to add generic ngrams? + new_ngrams = [] + if Tuples.is_generic(this_tuple[-1]) : + gvalue = this_tuple[-1] + for ngram in final_ngrams: + new_ngram = cn_ngram_replaced(ngram, gvalue.value.lower(), "<generic_value>") + if new_ngram != False: + new_ngrams.append(new_ngram) + + return dict([(ng.string_repn(), ng.score()) for ng in new_ngrams])
+ + + +
[docs]def get_cnngrams(cnet, max_ngrams, max_length): + active_ngrams = [] + finished_ngrams = [] + threshold = -5 + for sausage in cnet: + new_active_ngrams = [] + + for arc in sausage['arcs']: + if arc['score'] < threshold : + continue + this_ngram = cnNgram(arc['word'].lower(), arc['score']) + for ngram in active_ngrams: + + new_ngram = ngram + this_ngram + if len(new_ngram) < max_length : + new_active_ngrams.append(new_ngram) + # don't add ones ending in !NULL to finished + # as they need to end on a real word + # otherwise HELLO, HELLO !NULL, HELLO !NULL !NULL ...will accumulate + if arc['word'] != "!null" : + finished_ngrams.append(new_ngram) + elif arc['word'] != "!null" : + + finished_ngrams.append(new_ngram) + + if len(this_ngram) != 0 : + new_active_ngrams.append(this_ngram) + finished_ngrams.append(this_ngram) + + active_ngrams = cn_ngram_prune((new_active_ngrams[:]), int(1.5*max_ngrams)) + + return cn_ngram_prune(cn_ngram_merge(finished_ngrams), max_ngrams)
+ + +
[docs]class cnNgram(object): + + def __init__(self, words, logp, delta=0): + if not isinstance(words, type([])) : + words = words.split() + self.words = words + self.logp = logp + self.active = True + self.replacement_length_delta = delta + + +
[docs] def logscore(self): + return self.logp / len(self)
+ +
[docs] def score(self): + return math.exp(self.logscore())
+ + + def __add__(self, other): + return cnNgram(self.words + other.words, self.logp+other.logp) + + def __repr__(self, ): + return "%s : %.7f" % (" ".join(self.words), self.logp) + + def __len__(self): + return len([x for x in self.words if x != "!null"]) + self.replacement_length_delta + +
[docs] def word_list(self, ): + return [word for word in self.words if word != "!null"]
+ +
[docs] def string_repn(self, ): + return " ".join(self.word_list())
+ + + def __hash__(self): + # means sets work + string = self.string_repn() + return string.__hash__() + + def __eq__(self, other): + return self.string_repn() == other.string_repn()
+ +
[docs]def cn_ngram_merge(ngrams) : + # merge a list of ngrams + merged = {} + for ngram in ngrams: + if ngram not in merged : + merged[ngram] = ngram.logp + else : + merged[ngram] = math.log( math.exp(ngram.logp) + math.exp(merged[ngram]) ) + + new_ngrams = [] + for ngram in merged: + ngram.logp = merged[ngram] + new_ngrams.append(ngram) + return new_ngrams
+ +
[docs]def cn_ngram_prune(ngrams, n): + if len(ngrams) < n : + return ngrams + ngrams.sort(key=lambda x:-x.logscore()) + return ngrams[:n]
+ +
[docs]def cn_ngram_replaced(ngram, searchwords, replacement): + words = ngram.word_list() + searchwords = searchwords.split() + new_words = [] + found = False + i=0 + while i < len(words): + if words[i:i+len(searchwords)] == searchwords: + new_words.append(replacement) + found = True + i+=len(searchwords) + else : + new_words.append(words[i]) + i+=1 + if not found : + return False + out = cnNgram(new_words, ngram.logp, delta=len(searchwords) - 1) + return out
+ + + + + +if __name__ == '__main__': + cn = [ + {"arcs":[{"word":"<s>","score":0.0}]}, + {"arcs":[{"word":"hi","score":0.0}]}, + {"arcs":[{"word":"there","score":-math.log(2)}, {"word":"!null","score":-math.log(2)}]}, + {"arcs":[{"word":"how","score":0.0}]}, + {"arcs":[{"word":"are","score":0.0}]}, + {"arcs":[{"word":"you","score":0.0}]}, + {"arcs":[{"word":"</s>","score":0.0}]} + + ] + final_ngrams = get_cnngrams(cn,200,3) + print(dict([(ng.string_repn(), ng.score()) for ng in final_ngrams])) + import configparser, json, Tuples + config = configparser.ConfigParser() + config.read("output/experiments/feature_set/run_1.cfg") + nb = cnet(config) + log_file = json.load(open("corpora/data/Mar13_S2A0/voip-318851c80b-20130328_224811/log.json")) + log_turn = log_file["turns"][2] + print(nb.calculate( + log_turn + )) + tup = ("inform", "food", Tuples.genericValue("food", "modern european")) + print(nb.tuple_calculate(tup, log_turn)) +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/Tuples.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/Tuples.html
new file mode 100644
index 0000000..850bbc0
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/Tuples.html
@@ -0,0 +1,492 @@
+convlab.modules.nlu.multiwoz.svm.Tuples — ConvLab 0.1 documentation

Source code for convlab.modules.nlu.multiwoz.svm.Tuples

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+# deal with tuples and dialog acts
+import math
+import os
+import re
+import json
+
+from convlab.modules.nlu.multiwoz.svm import sutils
+
+
+
[docs]class tuples(object): + def __init__(self, config): + self.acts = json.loads(config.get("grammar", "acts")) + self.nonempty_acts = json.loads(config.get("grammar", "nonempty_acts")) + self.nonfull_acts = [act for act in self.acts if act not in self.nonempty_acts] + + rootpath=os.path.dirname(os.path.abspath(__file__)) + # if "semi" not in rootpath: + # rootpath+="/semi/CNetTrain/" + # else: + # rootpath+="/CNetTrain/" + self.ontology = json.load( + open(rootpath+'/'+config.get("grammar", "ontology")) + ) + + self.slots_informable = self.ontology["informable"] + self.slots = self.ontology["requestable"] + + self.slots_enumerated = json.loads(config.get("grammar", "slots_enumerated")) + self.config = config + self.all_tuples = self._getAllTuples() + self.max_active = 10 + if config.has_option("decode","max_active_tuples") : + self.max_active = int(config.get("decode","max_active_tuples")) + + self.tail_cutoff = 0.001 + if config.has_option("decode","tail_cutoff") : + self.tail_cutoff = float(config.get("decode","tail_cutoff")) + self.log_tail_cutoff = math.log(self.tail_cutoff) + + +
[docs] def uactsToTuples(self, uacts): + out = [] + for uact in uacts: + act =uact["act"] + if uact["slots"] == [] : + out.append((act,)) + for slot,value in uact["slots"]: + if act == "request" : + out.append(("request", value)) + elif slot in self.slots_informable or slot == "this": + if slot in self.slots_enumerated or slot == "this": + out.append((act,slot,value)) + else : + out.append((act,slot, genericValue(slot, value))) + return out
+ + def _getAllTuples(self): + out = [] + for slot in self.slots: + out.append(("request", slot)) + for x in self.ontology["all_tuples"]: + slot = x[1] + if slot in self.slots_enumerated: + out.append(tuple(x)) + else: + out.append((x[0], slot, genericValue(slot))) + out.append((x[0], slot, "do n't care")) + # all_tuples = [] + # for x in self.ontology["all_tuples"]: + # if x[0]=='request': + # all_tuples.append(tuple(x)) + # else: + # slot = x[1] + # if slot in self.slots_enumerated or slot == "this": + # all_tuples.append(tuple(x)) + # else: + # all_tuples.append((x[0],x[1],genericValue(x[1], x[2]))) + # return all_tuples + + # out = [] + # for slot in self.slots: + # out.append(("request", slot)) + # for act in self.nonempty_acts: + # if act == "request" : + # continue + # for slot in self.slots_informable: + # if slot in self.slots_enumerated : + # for value in self.ontology["informable"][slot] : + # out.append((act,slot,value)) + # + # else : + # out.append((act,slot, genericValue(slot))) + # out.append((act, slot, "do nt care")) + # for slot in self.slots_informable: + # out.append(("inform",slot,"do nt care")) + + # for act in self.nonfull_acts: + # out.append((act,)) + return list(set(out)) + +
[docs] def activeTuples(self, log_turn): + asr_hyps = log_turn["input"]["live"]["asr-hyps"] + out = [] + asr_hyps_conc = ", ".join([asr_hyp['asr-hyp'].lower() for asr_hyp in asr_hyps]) + for this_tuple in self.all_tuples: + if is_generic(this_tuple[-1]) : + # this is a generic value + act, slot, gvalue = this_tuple + for value in self.ontology["informable"][this_tuple[-2]]: + if value.lower() in asr_hyps_conc : + out.append((act, slot, genericValue(slot, value))) + if slot == 'Phone': + matchObj = re.search(r'\d{11}',asr_hyps_conc) + if matchObj: + out.append((act, slot, genericValue(slot, matchObj.group()))) + elif slot == 'Ticket': + matchObj = re.search(r'([0-9.]*?) (GBP|gbp)', asr_hyps_conc) + if matchObj: + out.append((act, slot, genericValue(slot, matchObj.group()))) + elif slot == 'Ref': + matchObj = re.search(r'reference number is(\s*?)([a-zA-Z0-9]+)', asr_hyps_conc) + if matchObj: + out.append((act, slot, genericValue(slot, matchObj.group(2)))) + elif slot == 'Time' or slot == 'Arrive' or slot == 'Leave': + matchObj = re.search(r'\d+?:\d\d', asr_hyps_conc) + if matchObj: + out.append((act, slot, genericValue(slot, matchObj.group(0)))) + else : + out.append(this_tuple) + return out
+ +
[docs] def activeTuples_sent(self, log_turn): + asr_hyps = log_turn["asr-hyps"] + out = [] + asr_hyps_conc = ", ".join([asr_hyp['asr-hyp'].lower() for asr_hyp in asr_hyps]) + for this_tuple in self.all_tuples: + if is_generic(this_tuple[-1]) : + # this is a generic value + act, slot, gvalue = this_tuple + for value in self.ontology["informable"][this_tuple[-2]]: + if value.lower() in asr_hyps_conc : + out.append((act, slot, genericValue(slot, value))) + if slot == 'Phone': + matchObj = re.search(r'\d{11}',asr_hyps_conc) + if matchObj: + out.append((act, slot, genericValue(slot, matchObj.group()))) + elif slot == 'Ticket': + matchObj = re.search(r'([0-9.]*?) (GBP|gbp)', asr_hyps_conc) + if matchObj: + out.append((act, slot, genericValue(slot, matchObj.group()))) + elif slot == 'Ref': + matchObj = re.search(r'reference number is(\s*?)([a-zA-Z0-9]+)', asr_hyps_conc) + if matchObj: + out.append((act, slot, genericValue(slot, matchObj.group(2)))) + elif slot == 'Time' or slot == 'Arrive' or slot == 'Leave': + matchObj = re.search(r'\d+?:\d\d', asr_hyps_conc) + if matchObj: + out.append((act, slot, genericValue(slot, matchObj.group(0)))) + else : + out.append(this_tuple) + return out
+ +
[docs] def distributionToNbest(self, tuple_distribution): + # convert a tuple distribution to an nbest list + tuple_distribution = tuple_distribution.items() + output = [] + ps = [p for _t,p in tuple_distribution] + eps = 0.00001 + tuple_distribution = [(t, math.log(max(eps,p)), math.log(max(eps, 1-p))) for t,p in tuple_distribution if p > 0] + tuple_distribution = sorted(tuple_distribution,key=lambda x:-x[1]) + # prune + tuple_distribution = tuple_distribution[:self.max_active] + + n = len(tuple_distribution) + powerset = sutils.powerset(range(n)) + acts = [] + for subset in powerset: + act = [] + score = 0 + for i in range(n): + this_tuple, logp, log1_p = tuple_distribution[i] + if i in subset : + act.append(this_tuple) + score += logp + else : + score += log1_p + if (score> self.log_tail_cutoff or len(act) == 0) and makes_valid_act(act) : + acts.append((act,score)) + if len(act) ==0 : + null_score = score + acts = sorted(acts,key=lambda x:-x[1]) + + acts = acts[:10] + found_null = False + for act,score in acts: + if len(act) == 0: + found_null = True + break + if not found_null : + acts.append(([], null_score)) + + #normalise + acts = [(act,math.exp(logp)) for act,logp in acts] + totalp = sum([p for act,p in acts]) + acts = [{"slu-hyp":[tuple_to_act(a) for a in act],"score":p/totalp} for act,p in acts] + return acts
+ +
+def tuple_to_act(t):
+    if len(t) == 1:
+        return {"act": t[0], "slots": []}
+    if len(t) == 2:
+        assert t[0] == "request"
+        return {"act": "request", "slots": [["slot", t[1]]]}
+    else:
+        return {"act": t[0], "slots": [[t[1], t[2]]]}
+ + + +
+def makes_valid_act(tuples):
+    # check if uacts is a valid list of tuples
+    # - can't affirm and negate
+    # - can't deny and inform same thing
+    # - can't inform(a=x) and inform(a=y) if x != y
+    if ("affirm",) in tuples and ("negate",) in tuples:
+        return False
+    triples = [t for t in tuples if len(t) == 3]
+    informed = [(slot, value) for act, slot, value in triples if act == "inform"]
+    denied = [(slot, value) for act, slot, value in triples if act == "deny"]
+    for s, v in informed:
+        if (s, v) in denied:
+            return False
+    informed_slots = [slot for slot, _value in informed]
+    if len(informed_slots) != len(set(informed_slots)):
+        return False
+    return True
+ +
+def actual_value(value):
+    try:
+        return value.value
+    except AttributeError:
+        return value
+ + +
[docs]class genericValue(object): + # useful class to use to represent a generic value + # x = genericValue("food") + # y = genericValue("food","chinese") + # z = genericValue("food","indian") + # x == y + # y in [x] + # y.value != z.value + + def __init__(self, slot, value=None): + self.slot = slot + self.value = value + + def __str__(self): + paren = "" + if self.value is not None : + paren = " (%s)" % self.value + return ("(generic value for %s"% self.slot) + paren + ")" + + def __repr__(self): + return self.__str__() + + def __eq__(self, other): + try: + return self.slot == other.slot + except AttributeError : + return False + + def __hash__(self): + return self.slot.__hash__()
+ + +
+def is_generic(value):
+    return not isinstance(value, str)
+ +
+def generic_to_specific(tup):
+    if len(tup) == 3:
+        act, slot, value = tup
+        value = actual_value(value)
+        return (act, slot, value)
+    return tup
+
+if __name__ == '__main__':
+
+    import configparser, json
+
+    config = configparser.ConfigParser()
+    config.read("config/multiwoz.cfg")
+    t = tuples(config)
+    dist = {('inform', 'food', 'indian'): 0.9, ('inform', 'food', 'indian2'): 1.0, ('hello',): 0.1}
+    print(dist)
+    nbest = t.distributionToNbest(dist)
+    print(nbest)
+
+    log_file = json.load(open("corpora/data/Mar13_S2A0/voip-318851c80b-20130328_224811/log.json"))
+    log_turn = log_file["turns"][2]
+    print(log_turn["input"]["batch"]["asr-hyps"][0])
+    print([tup for tup in t.activeTuples(log_turn) if tup[0] == "inform"])
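Editor's note: the helpers above are easiest to see in action. genericValue compares and hashes by slot only (so a valued instance matches its generic template), and tuple_to_act/generic_to_specific convert tuples back into dialog-act dicts. A small sketch:

# Editor's sketch of the Tuples helpers defined above.
from convlab.modules.nlu.multiwoz.svm.Tuples import (
    genericValue, generic_to_specific, tuple_to_act)

x = genericValue("food")
y = genericValue("food", "chinese")
print(x == y)                                       # True: equality ignores the value
print(generic_to_specific(("inform", "food", y)))   # ('inform', 'food', 'chinese')
print(tuple_to_act(("request", "area")))
# {'act': 'request', 'slots': [['slot', 'area']]}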
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/nlu.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/nlu.html
new file mode 100644
index 0000000..4f50c79
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/nlu.html
@@ -0,0 +1,271 @@
Source code for convlab.modules.nlu.multiwoz.svm.nlu

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import configparser
+import os
+import zipfile
+
+from convlab.lib.file_util import cached_path
+from convlab.modules.nlu.multiwoz.svm import Classifier
+from convlab.modules.nlu.nlu import NLU
+
+
+
+class SVMNLU(NLU):
+    def __init__(self,
+                 config_file=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config/multiwoz.cfg'),
+                 model_file=None):
+        self.config = configparser.ConfigParser()
+        self.config.read(config_file)
+        self.c = Classifier.classifier(self.config)
+        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                  self.config.get("train", "output"))
+        model_dir = os.path.dirname(model_path)
+        if not os.path.exists(model_path):
+            # unpack a pre-trained model archive if none is found locally
+            if not os.path.exists(model_dir):
+                os.makedirs(model_dir)
+            if not model_file:
+                print('Load from ', os.path.join(model_dir, 'svm_multiwoz.zip'))
+                archive = zipfile.ZipFile(os.path.join(model_dir, 'svm_multiwoz.zip'), 'r')
+            else:
+                print('Load from model_file param')
+                archive_file = cached_path(model_file)
+                archive = zipfile.ZipFile(archive_file, 'r')
+            archive.extractall(os.path.dirname(model_dir))
+            archive.close()
+        self.c.load(model_path)
+    def parse(self, utterance, context=None, not_empty=True):
+        sentinfo = {
+            "turn-id": 0,
+            "asr-hyps": [
+                {
+                    "asr-hyp": utterance,
+                    "score": 0
+                }
+            ]
+        }
+        slu_hyps = self.c.decode_sent(sentinfo, self.config.get("decode", "output"))
+        if not_empty:
+            # take the first non-empty hypothesis, if any
+            act_list = []
+            for hyp in slu_hyps:
+                if hyp['slu-hyp']:
+                    act_list = hyp['slu-hyp']
+                    break
+        else:
+            act_list = slu_hyps[0]['slu-hyp']
+        dialog_act = {}
+        for act in act_list:
+            intent = act['act']
+            if intent == 'request':
+                domain, slot = act['slots'][0][1].split('-')
+                intent = domain + '-' + intent.capitalize()
+                dialog_act.setdefault(intent, [])
+                dialog_act[intent].append([slot, '?'])
+            else:
+                dialog_act.setdefault(intent, [])
+                dialog_act[intent].append(act['slots'][0])
+        return dialog_act
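+# Illustrative only (actual output depends on the trained model): parse()
+# returns a dict mapping "Domain-Intent" (or a bare intent) to [slot, value]
+# pairs, e.g.
+#   nlu.parse("i want a cheap restaurant")
+#   # might yield {'Restaurant-Inform': [['Price', 'cheap']]}
+#   nlu.parse("what is the address ?")
+#   # might yield {'Restaurant-Request': [['Addr', '?']]}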
+
+if __name__ == "__main__":
+    nlu = SVMNLU()
+    test_utterances = [
+        "What type of accommodations are they. No , i just need their address . Can you tell me if the hotel has internet available ?",
+        "What type of accommodations are they.",
+        "No , i just need their address .",
+        "Can you tell me if the hotel has internet available ?",
+        "you're welcome! enjoy your visit! goodbye.",
+        "yes. it should be moderately priced.",
+        "i want to book a table for 6 at 18:45 on thursday",
+        "i will be departing out of stevenage.",
+        "What is the Name of attraction ?",
+        "Can I get the name of restaurant?",
+        "Can I get the address and phone number of the restaurant?",
+        "do you have a specific area you want to stay in?"
+    ]
+    for utt in test_utterances:
+        print(utt)
+        print(nlu.parse(utt))
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/preprocess.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/preprocess.html
new file mode 100644
index 0000000..844899b
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/preprocess.html
@@ -0,0 +1,343 @@
Source code for convlab.modules.nlu.multiwoz.svm.preprocess

+import json
+import os
+import zipfile
+from collections import Counter
+
+
+
+def read_zipped_json(filepath, filename):
+    # read `filename` from inside the zip archive at `filepath`
+    archive = zipfile.ZipFile(filepath, 'r')
+    return json.load(archive.open(filename))
+
+
+if __name__ == '__main__':
+    data_dir = '../../../../../data/multiwoz'
+    processed_data_dir = 'corpora/data'
+    data_key = ['val', 'test', 'train']
+    data = {}
+    for key in data_key:
+        data[key] = read_zipped_json(os.path.join(data_dir, key + '.json.zip'), key + '.json')
+        print('load {}, size {}'.format(key, len(data[key])))
+
+    db_dir = '../../../../../data/multiwoz/db'
+    db = {
+        'attraction': json.load(open(os.path.join(db_dir, 'attraction_db.json'))),
+        'hotel': json.load(open(os.path.join(db_dir, 'hotel_db.json'))),
+        'restaurant': json.load(open(os.path.join(db_dir, 'restaurant_db.json'))),
+        'police': json.load(open(os.path.join(db_dir, 'police_db.json'))),
+        'hospital': json.load(open(os.path.join(db_dir, 'hospital_db.json'))),
+        'taxi': json.load(open(os.path.join(db_dir, 'taxi_db.json'))),
+        'train': json.load(open(os.path.join(db_dir, 'train_db.json')))
+    }
+    # collect the observed values for each slot in each domain DB
+    domain2slot2value = {}
+    for domain in db.keys():
+        domain2slot2value[domain] = {}
+        if domain == 'taxi':
+            continue
+        for item in db[domain]:
+            for s, v in item.items():
+                if isinstance(v, str):
+                    domain2slot2value[domain].setdefault(s, Counter())
+                    domain2slot2value[domain][s] += Counter([v])
+                else:
+                    domain2slot2value[domain].setdefault(s, [])
+                    domain2slot2value[domain][s].append(v)
+
+    # collect requestable and informable slots from the training dialog acts
+    requestable_slots = []
+    informable_slots = []
+    for no, sess in data['train'].items():
+        for i, turn in enumerate(sess['log']):
+            for da, svs in turn['dialog_act'].items():
+                if 'Request' in da:
+                    requestable_slots.extend([da.split('-')[0] + '-' + s for s, v in svs])
+                else:
+                    informable_slots.extend([s for s, v in svs])
+    requestable_slots = list(set(requestable_slots))
+    informable_slots = list(set(informable_slots))
+
+    def slot2all_value(slot):
+        # union of the values this slot takes across all domain DBs
+        all_value = []
+        for domain in domain2slot2value.keys():
+            if slot in domain2slot2value[domain]:
+                all_value.extend(list(domain2slot2value[domain][slot]))
+        return list(set(all_value))
+
+    informable_onto = {}
+    informable_onto['Fee'] = slot2all_value('entrance fee')
+    informable_onto['Addr'] = slot2all_value('address')
+    informable_onto['Area'] = slot2all_value('area')
+    informable_onto['Stars'] = slot2all_value('stars') + ['zero', 'one', 'two', 'three', 'four', 'five']
+    informable_onto['Internet'] = slot2all_value('internet')
+    informable_onto['Department'] = slot2all_value('department')
+    informable_onto['Stay'] = [str(i) for i in range(10)] + ['zero', 'one', 'two', 'three', 'four',
+                                                             'five', 'six', 'seven', 'eight', 'nine', 'ten']
+    informable_onto['Ref'] = []
+    informable_onto['Food'] = slot2all_value('food')
+    informable_onto['Type'] = slot2all_value('type')
+    informable_onto['Price'] = slot2all_value('pricerange')
+    informable_onto['Choice'] = [str(i) for i in range(20)] + ['zero', 'one', 'two', 'three', 'four',
+                                                               'five', 'six', 'seven', 'eight', 'nine', 'ten']
+    informable_onto['Phone'] = []
+    informable_onto['Ticket'] = list(domain2slot2value['train']['price'])
+    informable_onto['Day'] = slot2all_value('day')
+    informable_onto['Name'] = slot2all_value('name')
+    informable_onto['Car'] = [i + ' ' + j for i in ["black", "white", "red", "yellow", "blue", "grey"]
+                              for j in ["toyota", "skoda", "bmw", "honda", "ford", "audi", "lexus",
+                                        "volvo", "volkswagen", "tesla"]]
+    informable_onto['Leave'] = []
+    informable_onto['Time'] = slot2all_value('Duration')
+    informable_onto['Arrive'] = []
+    informable_onto['Post'] = slot2all_value('postcode')
+    informable_onto['none'] = ['none']
+    informable_onto['Depart'] = slot2all_value('departure')
+    informable_onto['People'] = [str(i) for i in range(10)] + ['zero', 'one', 'two', 'three', 'four',
+                                                               'five', 'six', 'seven', 'eight', 'nine', 'ten']
+    informable_onto['Dest'] = slot2all_value('destination')
+    informable_onto['Parking'] = slot2all_value('parking')
+    informable_onto['Id'] = slot2all_value('trainID')
+    # the `Open` slot is removed below
+
+    da2slot2value = {}
+    for d_key, d in data.items():
+        d_dir = os.path.join(processed_data_dir, d_key)
+        if not os.path.exists(d_dir):
+            os.makedirs(d_dir)
+        for no, sess in d.items():
+            label_json = {"session-id": no}
+            label_turns = []
+            log_json = {"session-id": no}
+            log_turns = []
+            for i, turn in enumerate(sess['log']):
+                assert isinstance(turn['dialog_act'], dict)
+                new_das = []
+                for da, svs in turn['dialog_act'].items():
+                    for s, v in svs:
+                        if s == 'Open':
+                            continue
+                        if 'Request' in da:
+                            domain, act = da.split('-')
+                            new_das.append({'act': 'request', 'slots': [['slot', domain + '-' + s]]})
+                        else:
+                            new_das.append({'act': da, 'slots': [[s, v.lower()]]})
+                            da2slot2value.setdefault(da, {})
+                            da2slot2value[da].setdefault(s, [])
+                            da2slot2value[da][s].append(v)
+                label_turns.append({'semantics': {'json': new_das}})
+                log_turn = {'input': {'live': {'asr-hyps': [{'asr-hyp': turn['text'], 'score': 0}]}}}
+                log_turns.append(log_turn)
+            label_json['turns'] = label_turns
+            log_json['turns'] = log_turns
+            f_dir = os.path.join(d_dir, no)
+            if not os.path.exists(f_dir):
+                os.makedirs(f_dir)
+
+            json.dump(label_json, open(os.path.join(f_dir, 'label.json'), 'w'), indent=4)
+            json.dump(log_json, open(os.path.join(f_dir, 'log.json'), 'w'), indent=4)
+
+    all_tuples = []
+    slots_enumerated = ["Area", "Type", "Price", "Day", "Internet", "none", "Parking"]
+    for da, sv in da2slot2value.items():
+        if 'Request' in da:
+            pass
+        else:
+            for s, v in sv.items():
+                v_cnt = Counter(v)
+                if s not in slots_enumerated:
+                    all_tuples.append((da, s))
+                else:
+                    for i, c in dict(v_cnt).items():
+                        if c > 0 and i in informable_onto[s] or i.lower() in informable_onto[s]:
+                            all_tuples.append((da, s, i))
+
+    ontology_multiwoz = {
+        "requestable": requestable_slots,
+        "informable": informable_onto,
+        "all_tuples": all_tuples
+    }
+    json.dump(ontology_multiwoz, open('corpora/scripts/config/ontology_multiwoz.json', 'w'), indent=4)
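+# For reference, each dialog processed above yields a pair of files shaped
+# roughly as follows (sketch reconstructed from the code; values illustrative):
+#   label.json: {"session-id": "...", "turns": [{"semantics": {"json": [
+#       {"act": "request", "slots": [["slot", "Hotel-Addr"]]},
+#       {"act": "Hotel-Inform", "slots": [["Area", "centre"]]}]}, ...]}
+#   log.json:   {"session-id": "...", "turns": [{"input": {"live": {"asr-hyps": [
+#       {"asr-hyp": "...", "score": 0}]}}}, ...]}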
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/sutils.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/sutils.html
new file mode 100644
index 0000000..3e9e273
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/sutils.html
@@ -0,0 +1,261 @@
Source code for convlab.modules.nlu.multiwoz.svm.sutils

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+# misc useful functions
+
+import imp
+import os
+
+
+
+def dataset_walker(dataset=None, dataroot=None, labels=None):
+    # we assume that the dataset_walker class in dataroot/../../scripts
+    # is the one to use
+    scripts_folder = os.path.join(dataroot, '../..', "scripts")
+    _dw = imp.load_source('dataset_walker', os.path.join(scripts_folder, "dataset_walker.py"))
+    return _dw.dataset_walker(dataset, dataroot=dataroot, labels=labels)
+ +
+def import_class(cl):
+    # import a class given its fully qualified name, e.g. "pkg.module.ClassName"
+    d = cl.rfind(".")
+    classname = cl[d + 1:len(cl)]
+    m = __import__(cl[0:d], globals(), locals(), [classname])
+    return getattr(m, classname)
+
+from itertools import chain, combinations
+
+def powerset(iterable):
+    "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
+    s = list(iterable)
+    return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
+ + +
+def svm_to_libsvm(model, labels=None):
+    # convert an sklearn SVM object into a file in the format of LIBSVM's sparse SVMs
+    # (actually return the file's lines in an array)
+    lines = []
+    n_classes = model.coef_.shape[0] + 1
+    total_n_SV, n_feats = model.support_vectors_.shape
+    n_SV = model.n_support_
+
+    SV = model.support_vectors_
+
+    dual_coef = model.dual_coef_.todense()
+    b = model.intercept_
+
+    probA = model.probA_
+    probB = model.probB_
+
+    lines.append("svm_type")
+    lines.append("nr_class %i" % n_classes)
+    lines.append("total_sv %i" % total_n_SV)
+
+    lines.append("rho " + " ".join(["%.12f" % -c for c in b]))
+
+    if labels is None:
+        labels = map(str, range(n_classes))
+    lines.append("label " + " ".join(labels))
+
+    lines.append("probA " + " ".join(["%.12f" % v for v in probA]))
+    lines.append("probB " + " ".join(["%.12f" % v for v in probB]))
+
+    lines.append("nr_sv " + " ".join(["%i" % v for v in n_SV]))
+
+    lines.append("SV")
+    SV = SV.tocsc()
+
+    for i in range(total_n_SV):
+        # coefs for support vector i are in the ith column of dual_coef
+        this_line = ""
+        for c in dual_coef[:, i]:
+            this_line += ("%.12f " % c)
+        sv = SV[i, :].tocoo()
+        for j, v in zip(sv.col, sv.data):
+            this_line += ("%i:%.12f " % (j, v))
+        lines.append(this_line)
+    return lines
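+# Minimal export sketch with toy data (hedged example, not from the original
+# source; assumes scikit-learn and scipy are installed; kept as comments so it
+# does not run on import):
+#
+#   from scipy.sparse import csr_matrix
+#   from sklearn import svm
+#   import numpy as np
+#
+#   X = csr_matrix(np.array([[0., 0.], [1., 1.], [2., 2.], [3., 3.]]))
+#   y = np.array([0, 0, 1, 1])
+#   # fit on sparse input, so support_vectors_ and dual_coef_ are sparse,
+#   # as the .tocsc()/.todense() calls above assume
+#   clf = svm.SVC(kernel='linear', probability=True).fit(X, y)
+#   print("\n".join(svm_to_libsvm(clf, labels=["neg", "pos"])))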
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/train.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/train.html
new file mode 100644
index 0000000..78660ef
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/svm/train.html
@@ -0,0 +1,235 @@
Source code for convlab.modules.nlu.multiwoz.svm.train

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+
+import configparser
+import os
+import pprint
+import sys
+import zipfile
+
+from convlab.modules.nlu.multiwoz.svm import Classifier, sutils
+
+
+
+def train(config):
+    c = Classifier.classifier(config)
+    pprint.pprint(c.tuples.all_tuples)
+    print('All tuples:', len(c.tuples.all_tuples))
+    model_path = config.get("train", "output")
+    model_dir = os.path.dirname(model_path)
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    print('output to {}'.format(model_path))
+    dataroot = config.get("train", "dataroot")
+    dataset = config.get("train", "dataset")
+    dw = sutils.dataset_walker(dataset=dataset, dataroot=dataroot, labels=True)
+    c = Classifier.classifier(config)
+    c.cacheFeature(dw)
+    c.train(dw)
+    c.save(model_path)
+    # bundle the trained model so SVMNLU can unzip it later
+    with zipfile.ZipFile(os.path.join(model_dir, 'svm_multiwoz.zip'), 'w', zipfile.ZIP_DEFLATED) as zf:
+        zf.write(model_path)
+ +
+def usage():
+    print("usage:")
+    print("\t python train.py config/multiwoz.cfg")
+
+
+if __name__ == '__main__':
+    if len(sys.argv) != 2:
+        usage()
+        sys.exit()
+
+    config = configparser.ConfigParser()
+    try:
+        config.read(sys.argv[1])
+    except Exception as e:
+        print("Failed to parse file")
+        print(e)
+        sys.exit(1)
+
+    train(config)
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/multiwoz/utils.html b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/utils.html
new file mode 100644
index 0000000..e35dfdf
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/multiwoz/utils.html
@@ -0,0 +1,207 @@
Source code for convlab.modules.nlu.multiwoz.utils

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+'''
+'''
+
+import math
+
+import numpy as np
+
+
+
+def initWeights(n, d):
+    """ Initialization strategy: uniform Glorot/Xavier scaling. """
+    # scale_factor = 0.1
+    scale_factor = math.sqrt(float(6) / (n + d))
+    return (np.random.rand(n, d) * 2 - 1) * scale_factor
+ +
+def mergeDicts(d0, d1):
+    """ for all k in d1: d0[k] += d1[k]. d's are dictionaries of key -> numpy array """
+    for k in d1:
+        if k in d0:
+            d0[k] += d1[k]
+        else:
+            d0[k] = d1[k]
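+# Usage sketch (illustrative values, not in the original source):
+#   W = initWeights(4, 3)      # 4x3 matrix, uniform in [-sqrt(6/7), +sqrt(6/7)]
+#   grads = {'W': np.zeros((4, 3))}
+#   mergeDicts(grads, {'W': np.ones((4, 3)), 'b': np.zeros(3)})
+#   # grads['W'] is now all ones, and grads['b'] was copied over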
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/nlu/nlu.html b/docs/build/html/_modules/convlab/modules/nlu/nlu.html
new file mode 100644
index 0000000..22e523a
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/nlu/nlu.html
@@ -0,0 +1,206 @@
Source code for convlab.modules.nlu.nlu

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+"""
+
+
+class NLU:
+    """Base class for NLU model."""
+
+    def __init__(self):
+        """ Constructor for NLU class. """
+    def parse(self, utterance, context=None):
+        """
+        Predict the dialog act of a natural language utterance and apply error model.
+        Args:
+            utterance (str): The user input, a natural language utterance.
+            context: Dialog context, if the concrete model uses one.
+        Returns:
+            output (dict): The parsed dialog act of the input NL utterance.
+        """
+        pass
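+# A minimal subclass sketch (hypothetical, for illustration only):
+#
+#   class EchoNLU(NLU):
+#       def parse(self, utterance, context=None):
+#           # a trivial "parser" that never recognises any act
+#           return {}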
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/rule_based_multiwoz_bot.html b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/rule_based_multiwoz_bot.html
new file mode 100644
index 0000000..4795257
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/rule_based_multiwoz_bot.html
@@ -0,0 +1,865 @@
Source code for convlab.modules.policy.system.multiwoz.rule_based_multiwoz_bot

+import copy
+import json
+import random
+from copy import deepcopy
+
+from convlab.modules.policy.system.policy import SysPolicy
+from convlab.modules.util.multiwoz.dbquery import query
+from convlab.modules.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
+
+SELECTABLE_SLOTS = {
+    'Attraction': ['area', 'entrance fee', 'name', 'type'],
+    'Hospital': ['department'],
+    'Hotel': ['area', 'internet', 'name', 'parking', 'pricerange', 'stars', 'type'],
+    'Restaurant': ['area', 'name', 'food', 'pricerange'],
+    'Taxi': [],
+    'Train': [],
+    'Police': [],
+}
+
+INFORMABLE_SLOTS = ["Fee", "Addr", "Area", "Stars", "Internet", "Department", "Choice", "Ref", "Food", "Type", "Price",\
+                    "Stay", "Phone", "Post", "Day", "Name", "Car", "Leave", "Time", "Arrive", "Ticket", None, "Depart",\
+                    "People", "Dest", "Parking", "Open", "Id"]
+
+REQUESTABLE_SLOTS = ['Food', 'Area', 'Fee', 'Price', 'Type', 'Department', 'Internet', 'Parking', 'Stars']
+
+# Information required to finish booking, according to different domain.
+booking_info = {'Train': ['People'],
+                'Restaurant': ['Time', 'Day', 'People'],
+                'Hotel': ['Stay', 'Day', 'People']}
+
+# Alphabet used to generate Ref number
+alphabet = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
+
+# Judge if user has confirmed a unique choice, according to different domain
+token = {'Attraction': ['Name', 'Addr', ''],
+         'Hotel': ['Name', ]}
+
+
+
+class RuleBasedMultiwozBot(SysPolicy):
+    ''' Rule-based bot. Implemented for Multiwoz dataset. '''
+
+    recommend_flag = -1
+    choice = ""
+
+    def __init__(self):
+        SysPolicy.__init__(self)
+        self.last_state = {}
+    def init_session(self):
+        self.last_state = {}
+ +
+    def predict(self, state):
+        """
+        Args:
+            state (dict): Dialog state, please refer to util/state.py
+        Output:
+            DA (dict): Dialog act, in the form of
+            {act_type1: [[slot_name_1, value_1], [slot_name_2, value_2], ...], ...}
+        """
+        if self.recommend_flag != -1:
+            self.recommend_flag += 1
+
+        self.kb_result = {}
+
+        DA = {}
+
+        if 'user_action' in state and (len(state['user_action']) > 0):
+            user_action = state['user_action']
+        else:
+            # fall back to diffing belief states when no explicit user action is given
+            user_action = check_diff(self.last_state, state)
+
+        # debug info for the check_diff function
+        last_state_cpy = copy.deepcopy(self.last_state)
+        state_cpy = copy.deepcopy(state)
+        last_state_cpy.pop('history', None)
+        state_cpy.pop('history', None)
+        '''
+        if last_state_cpy != state_cpy:
+            print("Last state: ", last_state_cpy)
+            print("State: ", state_cpy)
+            print("Predicted action: ", user_action)
+        '''
+
+        self.last_state = state
+
+        for user_act in user_action:
+            domain, intent_type = user_act.split('-')
+
+            # respond to general greetings
+            if domain == 'general':
+                self._update_greeting(user_act, state, DA)
+            # book taxi for user
+            elif domain == 'Taxi':
+                self._book_taxi(user_act, state, DA)
+            elif domain == 'Booking':
+                self._update_booking(user_act, state, DA)
+            # user is talking about a domain other than train
+            elif domain != "Train":
+                self._update_DA(user_act, user_action, state, DA)
+            # info about train
+            else:
+                self._update_train(user_act, user_action, state, DA)
+
+            # judge whether the user wants to book
+            self._judge_booking(user_act, user_action, DA)
+
+            if 'Booking-Book' in DA:
+                # once a booking is confirmed, keep only the Booking-Book act and
+                # (with probability 0.5) ask whether the user needs anything else
+                if random.random() < 0.5:
+                    DA['general-reqmore'] = []
+                user_acts = []
+                for user_act in DA:
+                    if user_act != 'Booking-Book':
+                        user_acts.append(user_act)
+                for user_act in user_acts:
+                    del DA[user_act]
+
+        if DA == {}:
+            return {'general-greet': [['none', 'none']]}
+        return DA
+    def _update_greeting(self, user_act, state, DA):
+        """ General request / inform. """
+        _, intent_type = user_act.split('-')
+
+        # respond to goodbye
+        if intent_type == 'bye':
+            if 'general-bye' not in DA:
+                DA['general-bye'] = []
+            if random.random() < 0.3:
+                if 'general-welcome' not in DA:
+                    DA['general-welcome'] = []
+        elif intent_type == 'thank':
+            DA['general-welcome'] = []
+
+    def _book_taxi(self, user_act, state, DA):
+        """ Book a taxi for the user. """
+        blank_info = []
+        for info in ['departure', 'destination']:
+            if state['belief_state']['taxi']['semi'][info] == "":
+                info = REF_USR_DA['Taxi'].get(info, info)
+                blank_info.append(info)
+        if state['belief_state']['taxi']['semi']['leaveAt'] == "" and \
+                state['belief_state']['taxi']['semi']['arriveBy'] == "":
+            blank_info += ['Leave', 'Arrive']
+
+        # booking is complete: tell the user the car type and phone number
+        if len(blank_info) == 0:
+            if 'Taxi-Inform' not in DA:
+                DA['Taxi-Inform'] = []
+            car = generate_car()
+            phone_num = generate_ref_num(11)
+            DA['Taxi-Inform'].append(['Car', car])
+            DA['Taxi-Inform'].append(['Phone', phone_num])
+            return
+
+        # otherwise request a random subset of the missing information
+        request_num = random.randint(0, 999999) % len(blank_info) + 1
+        if 'Taxi-Request' not in DA:
+            DA['Taxi-Request'] = []
+        for i in range(request_num):
+            slot = REF_USR_DA.get(blank_info[i], blank_info[i])
+            DA['Taxi-Request'].append([slot, '?'])
+
+    def _update_booking(self, user_act, state, DA):
+        pass
+
+    def _update_DA(self, user_act, user_action, state, DA):
+        """ Answer user's utterance about any domain other than taxi or train. """
+        domain, intent_type = user_act.split('-')
+
+        constraints = []
+        for slot in state['belief_state'][domain.lower()]['semi']:
+            if state['belief_state'][domain.lower()]['semi'][slot] != "":
+                constraints.append([slot, state['belief_state'][domain.lower()]['semi'][slot]])
+
+        kb_result = query(domain.lower(), constraints)
+        self.kb_result[domain] = deepcopy(kb_result)
+
+        # respond to the user's request
+        if intent_type == 'Request':
+            if self.recommend_flag > 1:
+                self.recommend_flag = -1
+                self.choice = ""
+            elif self.recommend_flag == 1:
+                self.recommend_flag = 0
+            if (domain + "-Inform") not in DA:
+                DA[domain + "-Inform"] = []
+            for slot in user_action[user_act]:
+                if len(kb_result) > 0:
+                    kb_slot_name = REF_SYS_DA[domain].get(slot[0], slot[0])
+                    if kb_slot_name in kb_result[0]:
+                        DA[domain + "-Inform"].append([slot[0], kb_result[0][kb_slot_name]])
+                    else:
+                        DA[domain + "-Inform"].append([slot[0], "unknown"])
+
+        else:
+            # there's no result matching the user's constraints
+            if len(kb_result) == 0:
+                if (domain + "-NoOffer") not in DA:
+                    DA[domain + "-NoOffer"] = []
+                for slot in state['belief_state'][domain.lower()]['semi']:
+                    if state['belief_state'][domain.lower()]['semi'][slot] != "" and \
+                            state['belief_state'][domain.lower()]['semi'][slot] != "do n't care":
+                        slot_name = REF_USR_DA[domain].get(slot, slot)
+                        DA[domain + "-NoOffer"].append(
+                            [slot_name, state['belief_state'][domain.lower()]['semi'][slot]])
+
+                p = random.random()
+                # ask the user if they want to change a constraint
+                if p < 0.3:
+                    req_num = min(random.randint(0, 999999) % len(DA[domain + "-NoOffer"]) + 1, 3)
+                    if domain + "-Request" not in DA:
+                        DA[domain + "-Request"] = []
+                    for i in range(req_num):
+                        slot_name = REF_USR_DA[domain].get(DA[domain + "-NoOffer"][i][0],
+                                                           DA[domain + "-NoOffer"][i][0])
+                        DA[domain + "-Request"].append([slot_name, "?"])
+
+            # there's exactly one result matching the user's constraints
+            elif len(kb_result) == 1:
+                # inform the user about this result
+                if (domain + "-Inform") not in DA:
+                    DA[domain + "-Inform"] = []
+                props = []
+                for prop in state['belief_state'][domain.lower()]['semi']:
+                    props.append(prop)
+                property_num = len(props)
+                if property_num > 0:
+                    info_num = random.randint(0, 999999) % property_num + 1
+                    random.shuffle(props)
+                    for i in range(info_num):
+                        slot_name = REF_USR_DA[domain].get(props[i], props[i])
+                        DA[domain + "-Inform"].append([slot_name, kb_result[0][props[i]]])
+
+            # there are multiple results matching the user's constraints
+            else:
+                p = random.random()
+                # recommend a choice from the kb list
+                # (the elif/else branches below are currently unreachable because
+                #  of the `if True` guard kept from the source)
+                if True:  # p < 0.3:
+                    if (domain + "-Inform") not in DA:
+                        DA[domain + "-Inform"] = []
+                    if (domain + "-Recommend") not in DA:
+                        DA[domain + "-Recommend"] = []
+                    DA[domain + "-Inform"].append(["Choice", str(len(kb_result))])
+                    idx = random.randint(0, 999999) % len(kb_result)
+                    choice = kb_result[idx]
+                    if domain in ["Hotel", "Attraction", "Police", "Restaurant"]:
+                        DA[domain + "-Recommend"].append(['Name', choice['name']])
+                    self.recommend_flag = 0
+                    self.candidate = choice
+                    props = []
+                    for prop in choice:
+                        props.append([prop, choice[prop]])
+                    prop_num = min(random.randint(0, 999999) % 3, len(props))
+                    random.shuffle(props)
+                    for i in range(prop_num):
+                        slot = props[i][0]
+                        string = REF_USR_DA[domain].get(slot, slot)
+                        if string in INFORMABLE_SLOTS:
+                            DA[domain + "-Recommend"].append([string, str(props[i][1])])
+
+                # ask the user to choose a candidate
+                elif p < 0.5:
+                    prop_values = []
+                    props = []
+                    for prop in kb_result[0]:
+                        for candidate in kb_result:
+                            if prop not in candidate:
+                                continue
+                            if candidate[prop] not in prop_values:
+                                prop_values.append(candidate[prop])
+                        if len(prop_values) > 1:
+                            props.append([prop, prop_values])
+                        prop_values = []
+                    random.shuffle(props)
+                    idx = 0
+                    while idx < len(props):
+                        if props[idx][0] not in SELECTABLE_SLOTS[domain]:
+                            props.pop(idx)
+                            idx -= 1
+                        idx += 1
+                    if domain + "-Select" not in DA:
+                        DA[domain + "-Select"] = []
+                    for i in range(min(len(props[0][1]), 5)):
+                        prop_value = REF_USR_DA[domain].get(props[0][0], props[0][0])
+                        DA[domain + "-Select"].append([prop_value, props[0][1][i]])
+
+                # ask the user for more constraints
+                else:
+                    reqs = []
+                    for prop in state['belief_state'][domain.lower()]['semi']:
+                        if state['belief_state'][domain.lower()]['semi'][prop] == "":
+                            prop_value = REF_USR_DA[domain].get(prop, prop)
+                            reqs.append([prop_value, "?"])
+                    i = 0
+                    while i < len(reqs):
+                        if reqs[i][0] not in REQUESTABLE_SLOTS:
+                            reqs.pop(i)
+                            i -= 1
+                        i += 1
+                    random.shuffle(reqs)
+                    if len(reqs) == 0:
+                        return
+                    req_num = min(random.randint(0, 999999) % len(reqs) + 1, 2)
+                    if (domain + "-Request") not in DA:
+                        DA[domain + "-Request"] = []
+                    for i in range(req_num):
+                        req = reqs[i]
+                        req[0] = REF_USR_DA[domain].get(req[0], req[0])
+                        DA[domain + "-Request"].append(req)
+
+    def _update_train(self, user_act, user_action, state, DA):
+        constraints = []
+        for time in ['leaveAt', 'arriveBy']:
+            if state['belief_state']['train']['semi'][time] != "":
+                constraints.append([time, state['belief_state']['train']['semi'][time]])
+
+        if len(constraints) == 0:
+            p = random.random()
+            if 'Train-Request' not in DA:
+                DA['Train-Request'] = []
+            if p < 0.33:
+                DA['Train-Request'].append(['Leave', '?'])
+            elif p < 0.66:
+                DA['Train-Request'].append(['Arrive', '?'])
+            else:
+                DA['Train-Request'].append(['Leave', '?'])
+                DA['Train-Request'].append(['Arrive', '?'])
+
+        if 'Train-Request' not in DA:
+            DA['Train-Request'] = []
+        for prop in ['day', 'destination', 'departure']:
+            if state['belief_state']['train']['semi'][prop] == "":
+                slot = REF_USR_DA['Train'].get(prop, prop)
+                DA["Train-Request"].append([slot, '?'])
+            else:
+                constraints.append([prop, state['belief_state']['train']['semi'][prop]])
+
+        kb_result = query('train', constraints)
+        self.kb_result['Train'] = deepcopy(kb_result)
+
+        if user_act == 'Train-Request':
+            del DA['Train-Request']
+            if 'Train-Inform' not in DA:
+                DA['Train-Inform'] = []
+            for slot in user_action[user_act]:
+                slot_name = REF_SYS_DA['Train'].get(slot[0], slot[0])
+                try:
+                    DA['Train-Inform'].append([slot[0], kb_result[0][slot_name]])
+                except (IndexError, KeyError):
+                    pass
+            return
+        if len(kb_result) == 0:
+            if 'Train-NoOffer' not in DA:
+                DA['Train-NoOffer'] = []
+            for prop in constraints:
+                DA['Train-NoOffer'].append([REF_USR_DA['Train'].get(prop[0], prop[0]), prop[1]])
+            if 'Train-Request' in DA:
+                del DA['Train-Request']
+        elif len(kb_result) >= 1:
+            if len(constraints) < 4:
+                return
+            if 'Train-Request' in DA:
+                del DA['Train-Request']
+            if 'Train-OfferBook' not in DA:
+                DA['Train-OfferBook'] = []
+            for prop in constraints:
+                DA['Train-OfferBook'].append([REF_USR_DA['Train'].get(prop[0], prop[0]), prop[1]])
+
+    def _judge_booking(self, user_act, user_action, DA):
+        """ If the user wants to book, return a Ref number. """
+        if self.recommend_flag > 1:
+            self.recommend_flag = -1
+            self.choice = ""
+        elif self.recommend_flag == 1:
+            self.recommend_flag = 0
+        domain, _ = user_act.split('-')
+        for slot in user_action[user_act]:
+            if domain in booking_info and slot[0] in booking_info[domain]:
+                if 'Booking-Book' not in DA:
+                    if domain in self.kb_result and len(self.kb_result[domain]) > 0:
+                        DA['Booking-Book'] = [["Ref", self.kb_result[domain][0]['Ref']]]
+    # TODO handle booking across multiple turns
+def check_diff(last_state, state):
+    # infer the user action by diffing the current belief/request state
+    # against the previous turn's state
+    user_action = {}
+    if last_state == {}:
+        for domain in state['belief_state']:
+            for slot in state['belief_state'][domain]['book']:
+                if slot != 'booked' and state['belief_state'][domain]['book'][slot] != '':
+                    if (domain.capitalize() + "-Inform") not in user_action:
+                        user_action[domain.capitalize() + "-Inform"] = []
+                    if [REF_USR_DA[domain.capitalize()].get(slot, slot),
+                            state['belief_state'][domain]['book'][slot]] \
+                            not in user_action[domain.capitalize() + "-Inform"]:
+                        user_action[domain.capitalize() + "-Inform"].append(
+                            [REF_USR_DA[domain.capitalize()].get(slot, slot),
+                             state['belief_state'][domain]['book'][slot]])
+            for slot in state['belief_state'][domain]['semi']:
+                if state['belief_state'][domain]['semi'][slot] != "":
+                    if (domain.capitalize() + "-Inform") not in user_action:
+                        user_action[domain.capitalize() + "-Inform"] = []
+                    if [REF_USR_DA[domain.capitalize()].get(slot, slot),
+                            state['belief_state'][domain]['semi'][slot]] \
+                            not in user_action[domain.capitalize() + "-Inform"]:
+                        user_action[domain.capitalize() + "-Inform"].append(
+                            [REF_USR_DA[domain.capitalize()].get(slot, slot),
+                             state['belief_state'][domain]['semi'][slot]])
+        for domain in state['request_state']:
+            for slot in state['request_state'][domain]:
+                if (domain.capitalize() + "-Request") not in user_action:
+                    user_action[domain.capitalize() + "-Request"] = []
+                if [REF_USR_DA[domain.capitalize()].get(slot, slot), '?'] \
+                        not in user_action[domain.capitalize() + "-Request"]:
+                    user_action[domain.capitalize() + "-Request"].append(
+                        [REF_USR_DA[domain.capitalize()].get(slot, slot), '?'])
+
+    else:
+        for domain in state['belief_state']:
+            for slot in state['belief_state'][domain]['book']:
+                if slot != 'booked' and \
+                        state['belief_state'][domain]['book'][slot] != last_state['belief_state'][domain]['book'][slot]:
+                    if (domain.capitalize() + "-Inform") not in user_action:
+                        user_action[domain.capitalize() + "-Inform"] = []
+                    if [REF_USR_DA[domain.capitalize()].get(slot, slot),
+                            state['belief_state'][domain]['book'][slot]] \
+                            not in user_action[domain.capitalize() + "-Inform"]:
+                        user_action[domain.capitalize() + "-Inform"].append(
+                            [REF_USR_DA[domain.capitalize()].get(slot, slot),
+                             state['belief_state'][domain]['book'][slot]])
+            for slot in state['belief_state'][domain]['semi']:
+                if state['belief_state'][domain]['semi'][slot] != last_state['belief_state'][domain]['semi'][slot] and \
+                        state['belief_state'][domain]['semi'][slot] != '':
+                    if (domain.capitalize() + "-Inform") not in user_action:
+                        user_action[domain.capitalize() + "-Inform"] = []
+                    if [REF_USR_DA[domain.capitalize()].get(slot, slot),
+                            state['belief_state'][domain]['semi'][slot]] \
+                            not in user_action[domain.capitalize() + "-Inform"]:
+                        user_action[domain.capitalize() + "-Inform"].append(
+                            [REF_USR_DA[domain.capitalize()].get(slot, slot),
+                             state['belief_state'][domain]['semi'][slot]])
+        for domain in state['request_state']:
+            for slot in state['request_state'][domain]:
+                if (domain not in last_state['request_state']) or (slot not in last_state['request_state'][domain]):
+                    if (domain.capitalize() + "-Request") not in user_action:
+                        user_action[domain.capitalize() + "-Request"] = []
+                    if [REF_USR_DA[domain.capitalize()].get(slot, slot), '?'] \
+                            not in user_action[domain.capitalize() + "-Request"]:
+                        user_action[domain.capitalize() + "-Request"].append(
+                            [REF_USR_DA[domain.capitalize()].get(slot, slot), '?'])
+    return user_action
+ + +
+def deduplicate(lst):
+    # remove duplicates in place, preserving order
+    i = 0
+    while i < len(lst):
+        if lst[i] in lst[0:i]:
+            lst.pop(i)
+            i -= 1
+        i += 1
+    return lst
+ +
+def generate_ref_num(length):
+    """ Generate a reference number for booking. """
+    string = ""
+    while len(string) < length:
+        string += alphabet[random.randint(0, 999999) % 36]
+    return string
+ +
+def generate_car():
+    """ Generate a car type for taxi booking. """
+    car_types = ["toyota", "skoda", "bmw", "honda", "ford", "audi", "lexus", "volvo", "volkswagen", "tesla"]
+    p = random.randint(0, 999999) % len(car_types)
+    return car_types[p]
+ +
+def fake_state():
+    user_action = {'Hotel-Request': [['Name', '?']], 'Train-Inform': [['Day', 'don\'t care']]}
+    init_belief_state = {
+        "police": {"book": {"booked": []}, "semi": {}},
+        "hotel": {
+            "book": {"booked": [], "people": "", "day": "", "stay": ""},
+            "semi": {"name": "", "area": "", "parking": "", "pricerange": "",
+                     "stars": "", "internet": "", "type": ""}
+        },
+        "attraction": {"book": {"booked": []}, "semi": {"type": "", "name": "", "area": ""}},
+        "restaurant": {
+            "book": {"booked": [], "people": "", "day": "", "time": ""},
+            "semi": {"food": "", "pricerange": "", "name": "", "area": ""}
+        },
+        "hospital": {"book": {"booked": []}, "semi": {"department": ""}},
+        "taxi": {
+            "book": {"booked": []},
+            "semi": {"leaveAt": "", "destination": "", "departure": "", "arriveBy": ""}
+        },
+        "train": {
+            "book": {"booked": [], "people": ""},
+            "semi": {"leaveAt": "", "destination": "", "day": "", "arriveBy": "", "departure": ""}
+        }
+    }
+    kb_results = [
+        {'name': 'xxx_train', 'day': 'tuesday', 'dest': 'cam', 'phone': '123-3333', 'area': 'south'},
+        {'name': 'xxx_train', 'day': 'tuesday', 'dest': 'cam', 'phone': '123-3333', 'area': 'north'}
+    ]
+    state = {'user_action': user_action,
+             'belief_state': init_belief_state,
+             'kb_results_dict': kb_results,
+             'hotel-request': [['phone']]}
+    return state
+ + +
+def test_init_state():
+    user_action = ['general-hello']
+    current_slots = dict()
+    state = {'user_action': user_action,
+             'current_slots': current_slots,
+             'kb_results_dict': []}
+    return state
+ +
+def test_run():
+    policy = RuleBasedMultiwozBot()
+    system_act = policy.predict(fake_state())
+    print(json.dumps(system_act, indent=4))
+ + +
+class RuleInformBot(SysPolicy):
+    """ a simple, inform rule bot """
+
+    def __init__(self):
+        """ Constructor for RuleInformBot class. """
+        SysPolicy.__init__(self)
+
+        self.cur_inform_slot_id = 0
+        self.cur_request_slot_id = 0
+        self.domains = ['Taxi']
+    def init_session(self):
+        """ Reset the slot counters after one session. """
+        self.cur_inform_slot_id = 0
+        self.cur_request_slot_id = 0
+ +
+    def predict(self, state):
+        print('state', state.keys())
+        for key in state:
+            print(key, json.dumps(state[key], indent=2))
+
+        act_slot_response = {}
+        domain = self.domains[0]
+
+        if self.cur_inform_slot_id < len(REF_SYS_DA[domain]):
+            # first pass: inform each slot in turn
+            key = list(REF_SYS_DA[domain])[self.cur_inform_slot_id]
+            slot = REF_SYS_DA[domain][key]
+            diaact = domain + "-Inform"
+            val = generate_car()
+            act_slot_response[diaact] = []
+            act_slot_response[diaact].append([slot, val])
+            self.cur_inform_slot_id += 1
+        elif self.cur_request_slot_id < len(REF_SYS_DA[domain]):
+            # second pass: request each slot in turn
+            key = list(REF_SYS_DA[domain])[self.cur_request_slot_id]
+            slot = REF_SYS_DA[domain][key]
+            diaact = domain + "-Request"
+            val = "?"
+            act_slot_response[diaact] = []
+            act_slot_response[diaact].append([slot, val])
+            self.cur_request_slot_id += 1
+        else:
+            act_slot_response['general-hello'] = []
+            self.cur_request_slot_id = 0
+            self.cur_inform_slot_id = 0
+
+        return act_slot_response
+
+
+if __name__ == '__main__':
+    test_run()
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/util.html b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/util.html
new file mode 100644
index 0000000..cb55c5f
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/util.html
@@ -0,0 +1,386 @@
Source code for convlab.modules.policy.system.multiwoz.util

+"""
+Utility package for system policy
+"""
+
+import json
+import os
+
+import numpy as np
+
+from convlab.modules.policy.system.multiwoz.rule_based_multiwoz_bot import generate_car, generate_ref_num
+from convlab.modules.util.multiwoz.dbquery import query
+from convlab.modules.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
+
+DEFAULT_VOCAB_FILE=os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))), 
+    "data/multiwoz/da_slot_cnt.json")
+
+
+
+class SkipException(Exception):
+    pass
+ + +
+class ActionVocab(object):
+    def __init__(self, vocab_path=DEFAULT_VOCAB_FILE, num_actions=500):
+        # add general actions
+        self.vocab = [
+            {'general-welcome': ['none']},
+            {'general-greet': ['none']},
+            {'general-bye': ['none']},
+            {'general-reqmore': ['none']}
+        ]
+        # add single-slot actions
+        for domain in REF_SYS_DA:
+            for slot in REF_SYS_DA[domain]:
+                self.vocab.append({domain + '-Inform': [slot]})
+                self.vocab.append({domain + '-Request': [slot]})
+        # add the most frequent actions from corpus statistics
+        with open(vocab_path, 'r') as f:
+            stats = json.load(f)
+        for action_string in stats:
+            try:
+                act_strings = action_string.split(';];')
+                action_dict = {}
+                for act_string in act_strings:
+                    if act_string == '':
+                        continue
+                    domain_act, slots = act_string.split('[', 1)
+                    domain, act_type = domain_act.split('-')
+                    if act_type in ['NoOffer', 'OfferBook']:
+                        action_dict[domain_act] = ['none']
+                    elif act_type in ['Select']:
+                        if slots.startswith('none'):
+                            raise SkipException
+                        action_dict[domain_act] = [slots.split(';')[0]]
+                    else:
+                        action_dict[domain_act] = sorted(slots.split(';'))
+                if action_dict not in self.vocab:
+                    self.vocab.append(action_dict)
+            except SkipException:
+                print(act_strings)
+            if len(self.vocab) >= num_actions:
+                break
+        print("{} actions are added to vocab".format(len(self.vocab)))
+    def get_action(self, action_index):
+        return self.vocab[action_index]
+
+
+def _domain_fill(delex_action, state, action, act):
+    domain, act_type = act.split('-')
+    constraints = []
+    for slot in state['belief_state'][domain.lower()]['semi']:
+        if state['belief_state'][domain.lower()]['semi'][slot] != "":
+            constraints.append([slot, state['belief_state'][domain.lower()]['semi'][slot]])
+    if act_type in ['NoOffer', 'OfferBook']:  # NoOffer['none'], OfferBook['none']
+        action[act] = []
+        for slot in constraints:
+            action[act].append([REF_USR_DA[domain].get(slot[0], slot[0]), slot[1]])
+    elif act_type in ['Inform', 'Recommend', 'OfferBooked']:  # Inform[Slot,...], Recommend[Slot,...]
+        kb_result = query(domain.lower(), constraints)
+        if len(kb_result) == 0:
+            action[act] = [['none', 'none']]
+        else:
+            action[act] = []
+            for slot in delex_action[act]:
+                if slot == 'Choice':
+                    action[act].append([slot, len(kb_result)])
+                elif slot == 'Ref':
+                    action[act].append(["Ref", generate_ref_num(8)])
+                else:
+                    try:
+                        action[act].append([slot, kb_result[0][REF_SYS_DA[domain].get(slot, slot)]])
+                    except KeyError:
+                        action[act].append([slot, "N/A"])
+            if len(action[act]) == 0:
+                action[act] = [['none', 'none']]
+    elif act_type in ['Select']:  # Select[Slot]
+        kb_result = query(domain.lower(), constraints)
+        if len(kb_result) < 2:
+            action[act] = [['none', 'none']]
+        else:
+            slot = delex_action[act][0]
+            action[act] = []
+            action[act].append([slot, kb_result[0][REF_SYS_DA[domain].get(slot, slot)]])
+            action[act].append([slot, kb_result[1][REF_SYS_DA[domain].get(slot, slot)]])
+    else:
+        print('Cannot decode:', str(delex_action))
+        action[act] = [['none', 'none']]
+
+
+def action_decoder(state, action_index, action_vocab):
+    domains = ['Attraction', 'Hospital', 'Hotel', 'Restaurant', 'Taxi', 'Train', 'Police']
+    delex_action = action_vocab.get_action(action_index)
+    action = {}
+
+    for act in delex_action:
+        domain, act_type = act.split('-')
+        if act_type == 'Request':
+            action[act] = []
+            for slot in delex_action[act]:
+                action[act].append([slot, '?'])
+        elif act == 'Booking-Book':
+            action['Booking-Book'] = [["Ref", generate_ref_num(8)]]
+        elif domain not in domains:
+            action[act] = [['none', 'none']]
+        else:
+            if act == 'Taxi-Inform':
+                for info_slot in ['leaveAt', 'arriveBy']:
+                    if info_slot in state['belief_state']['taxi']['semi'] and \
+                            state['belief_state']['taxi']['semi'][info_slot] != "":
+                        car = generate_car()
+                        phone_num = generate_ref_num(11)
+                        action[act] = []
+                        action[act].append(['Car', car])
+                        action[act].append(['Phone', phone_num])
+                        break
+                else:
+                    action[act] = [['none', 'none']]
+            elif act in ['Train-Inform', 'Train-NoOffer', 'Train-OfferBook']:
+                for info_slot in ['departure', 'destination']:
+                    if info_slot not in state['belief_state']['train']['semi'] or \
+                            state['belief_state']['train']['semi'][info_slot] == "":
+                        action[act] = [['none', 'none']]
+                        break
+                else:
+                    for info_slot in ['leaveAt', 'arriveBy']:
+                        if info_slot in state['belief_state']['train']['semi'] and \
+                                state['belief_state']['train']['semi'][info_slot] != "":
+                            _domain_fill(delex_action, state, action, act)
+                            break
+                    else:
+                        action[act] = [['none', 'none']]
+            elif domain in domains:
+                _domain_fill(delex_action, state, action, act)
+
+    return action
+ +
+def one_hot(num, domain, domains, vector):
+    """Encode the number of available entities for a domain as a 6-way one-hot bin."""
+    if domain != 'train':
+        idx = domains.index(domain)
+        if num == 0:
+            vector[idx * 6: idx * 6 + 6] = np.array([1, 0, 0, 0, 0, 0])
+        elif num == 1:
+            vector[idx * 6: idx * 6 + 6] = np.array([0, 1, 0, 0, 0, 0])
+        elif num == 2:
+            vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 1, 0, 0, 0])
+        elif num == 3:
+            vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 1, 0, 0])
+        elif num == 4:
+            vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 1, 0])
+        elif num >= 5:
+            vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 0, 1])
+    else:
+        # the train domain typically has many matches, so use coarser bins
+        idx = domains.index(domain)
+        if num == 0:
+            vector[idx * 6: idx * 6 + 6] = np.array([1, 0, 0, 0, 0, 0])
+        elif num <= 2:
+            vector[idx * 6: idx * 6 + 6] = np.array([0, 1, 0, 0, 0, 0])
+        elif num <= 5:
+            vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 1, 0, 0, 0])
+        elif num <= 10:
+            vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 1, 0, 0])
+        elif num <= 40:
+            vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 1, 0])
+        elif num > 40:
+            vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 0, 1])
+
+    return vector
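+# Example of the binned encoding (illustrative, not part of the source): with
+# domains = ['restaurant', 'train'] and a zeroed vector of length 12,
+#   one_hot(3, 'restaurant', domains, vector)  # sets slots 0-5 to [0,0,0,1,0,0]
+#   one_hot(7, 'train', domains, vector)       # sets slots 6-11 to [0,0,0,1,0,0]
+# i.e. the train domain buckets counts into {0, 1-2, 3-5, 6-10, 11-40, >40}.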
+
+
+if __name__ == '__main__':
+    pass
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/dataset_reader.html b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/dataset_reader.html
new file mode 100644
index 0000000..1a00d08
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/dataset_reader.html
@@ -0,0 +1,302 @@
Source code for convlab.modules.policy.system.multiwoz.vanilla_mle.dataset_reader

+import json
+import logging
+import math
+import os
+import zipfile
+from typing import Dict
+
+import numpy as np
+from allennlp.data.dataset_readers.dataset_reader import DatasetReader
+from allennlp.data.fields import ArrayField, LabelField, Field
+from allennlp.data.instance import Instance
+from overrides import overrides
+
+from convlab.lib.file_util import cached_path
+from convlab.modules.dst.multiwoz.rule_dst import RuleDST
+from convlab.modules.policy.system.multiwoz.util import ActionVocab
+from convlab.modules.state_encoder.multiwoz.multiwoz_state_encoder import MultiWozStateEncoder
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+
+@DatasetReader.register("mle_policy")
+class MlePolicyDatasetReader(DatasetReader):
+    """
+    Reads MultiWOZ dialogs and yields (encoded state, system action index) instances.
+
+    Parameters
+    ----------
+    num_actions : size of the delexicalised action vocabulary
+    """
+    def __init__(self,
+                 num_actions: int,
+                 lazy: bool = False) -> None:
+        super().__init__(lazy)
+        self.dst = RuleDST()
+        self.action_vocab = ActionVocab(num_actions=num_actions)
+        self.action_list = self.action_vocab.vocab
+        self.state_encoder = MultiWozStateEncoder()
+
+    @overrides
+    def _read(self, file_path):
+        # if `file_path` is a URL, redirect to the cache
+        file_path = cached_path(file_path)
+
+        if file_path.endswith("zip"):
+            archive = zipfile.ZipFile(file_path, "r")
+            data_file = archive.open(os.path.basename(file_path)[:-4])
+        else:
+            data_file = open(file_path, "r")
+
+        logger.info("Reading instances from lines in file at: %s", file_path)
+
+        dialogs = json.load(data_file)
+
+        for dial_name in dialogs:
+            dialog = dialogs[dial_name]["log"]
+            self.dst.init_session()
+            for i, turn in enumerate(dialog):
+                if i % 2 == 0:  # user turn: update the dialog state tracker
+                    self.dst.update(user_act=turn["dialog_act"])
+                else:  # system turn: delexicalise the system act and emit an instance
+                    delex_act = {}
+                    for domain_act in turn["dialog_act"]:
+                        domain, act_type = domain_act.split('-', 1)
+                        if act_type in ['NoOffer', 'OfferBook']:
+                            delex_act[domain_act] = ['none']
+                        elif act_type in ['Select']:
+                            for sv in turn["dialog_act"][domain_act]:
+                                if sv[0] != "none":
+                                    delex_act[domain_act] = [sv[0]]
+                                    break
+                        else:
+                            delex_act[domain_act] = [sv[0] for sv in turn["dialog_act"][domain_act]]
+                    state_vector = self.state_encoder.encode(self.dst.state)
+                    action_index = self.find_best_delex_act(delex_act)
+
+                    yield self.text_to_instance(state_vector, action_index)
+    def find_best_delex_act(self, action):
+        def _score(a1, a2):
+            # number of (act, slot) entries of a1 that are missing from a2
+            score = 0
+            for domain_act in a1:
+                if domain_act not in a2:
+                    score += len(a1[domain_act])
+                else:
+                    score += len(set(a1[domain_act]) - set(a2[domain_act]))
+            return score
+
+        best_p_action_index = -1
+        best_p_score = math.inf
+        best_pn_action_index = -1
+        best_pn_score = math.inf
+        for i, v_action in enumerate(self.action_list):
+            if v_action == action:
+                return i
+            else:
+                p_score = _score(action, v_action)
+                n_score = _score(v_action, action)
+                if p_score > 0 and n_score == 0 and p_score < best_p_score:
+                    best_p_action_index = i
+                    best_p_score = p_score
+                else:
+                    if p_score + n_score < best_pn_score:
+                        best_pn_action_index = i
+                        best_pn_score = p_score + n_score
+        if best_p_action_index >= 0:
+            return best_p_action_index
+        return best_pn_action_index
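+# Scoring sketch (illustrative, not part of the source): for a gold act
+#   a = {'Hotel-Inform': ['Area', 'Price']}
+# a vocab entry {'Hotel-Inform': ['Area']} has p_score 1 (one gold slot missing)
+# and n_score 0 (no extra slots), so it is preferred over an entry that adds
+# slots the gold act does not contain.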
+ +
+    def text_to_instance(self, state: np.ndarray, action: int = None) -> Instance:  # type: ignore
+        # pylint: disable=arguments-differ
+        fields: Dict[str, Field] = {}
+        fields["states"] = ArrayField(state)
+        if action is not None:
+            fields["actions"] = LabelField(action, skip_indexing=True)
+        return Instance(fields)
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/evaluate.html b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/evaluate.html
new file mode 100644
index 0000000..3414a7e
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/evaluate.html
@@ -0,0 +1,305 @@
Source code for convlab.modules.policy.system.multiwoz.vanilla_mle.evaluate

+"""
+The ``evaluate`` subcommand can be used to
+evaluate a trained model against a dataset
+and report any metrics calculated by the model.
+"""
+import argparse
+import json
+import logging
+from typing import Dict, Any
+
+from allennlp.common import Params
+from allennlp.common.util import prepare_environment
+from allennlp.data.dataset_readers.dataset_reader import DatasetReader
+from allennlp.data.iterators import DataIterator
+from allennlp.models.archival import load_archive
+from allennlp.training.util import evaluate
+
+from convlab.modules.policy.system.multiwoz.vanilla_mle import dataset_reader, model 
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+argparser = argparse.ArgumentParser(description="Evaluate the specified model + dataset.")
+argparser.add_argument('archive_file', type=str, help='path to an archived trained model')
+
+argparser.add_argument('input_file', type=str, help='path to the file containing the evaluation data')
+
+argparser.add_argument('--output-file', type=str, help='path to output file')
+
+argparser.add_argument('--weights-file',
+                        type=str,
+                        help='a path that overrides which weights file to use')
+
+cuda_device = argparser.add_mutually_exclusive_group(required=False)
+cuda_device.add_argument('--cuda-device',
+                            type=int,
+                            default=-1,
+                            help='id of GPU to use (if any)')
+
+argparser.add_argument('-o', '--overrides',
+                        type=str,
+                        default="",
+                        help='a JSON structure used to override the experiment configuration')
+
+argparser.add_argument('--batch-weight-key',
+                        type=str,
+                        default="",
+                        help='If non-empty, name of metric used to weight the loss on a per-batch basis.')
+
+argparser.add_argument('--extend-vocab',
+                        action='store_true',
+                        default=False,
+                        help='if specified, we will use the instances in your new dataset to '
+                            'extend your vocabulary. If pretrained-file was used to initialize '
+                            'embedding layers, you may also need to pass --embedding-sources-mapping.')
+
+argparser.add_argument('--embedding-sources-mapping',
+                        type=str,
+                        default="",
+                        help='a JSON dict defining mapping from embedding module path to embedding'
+                        'pretrained-file used during training. If not passed, and embedding needs to be '
+                        'extended, we will try to use the original file paths used during training. If '
+                        'they are not available we will use random vectors for embedding extension.')
+
+
+
[docs]def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]: + # Disable some of the more verbose logging statements + logging.getLogger('allennlp.common.params').disabled = True + logging.getLogger('allennlp.nn.initializers').disabled = True + logging.getLogger('allennlp.modules.token_embedders.embedding').setLevel(logging.INFO) + + # Load from archive + archive = load_archive(args.archive_file, args.cuda_device, args.overrides, args.weights_file) + config = archive.config + prepare_environment(config) + model = archive.model + model.eval() + + # Load the evaluation data + + # Try to use the validation dataset reader if there is one - otherwise fall back + # to the default dataset_reader used for both training and validation. + validation_dataset_reader_params = config.pop('validation_dataset_reader', None) + if validation_dataset_reader_params is not None: + dataset_reader = DatasetReader.from_params(validation_dataset_reader_params) + else: + dataset_reader = DatasetReader.from_params(config.pop('dataset_reader')) + evaluation_data_path = args.input_file + logger.info("Reading evaluation data from %s", evaluation_data_path) + instances = dataset_reader.read(evaluation_data_path) + + embedding_sources: Dict[str, str] = (json.loads(args.embedding_sources_mapping) if args.embedding_sources_mapping else {}) + if args.extend_vocab: + logger.info("Vocabulary is being extended with test instances.") + model.vocab.extend_from_instances(Params({}), instances=instances) + model.extend_embedder_vocab(embedding_sources) + + iterator_params = config.pop("validation_iterator", None) + if iterator_params is None: + iterator_params = config.pop("iterator") + iterator = DataIterator.from_params(iterator_params) + iterator.index_with(model.vocab) + + metrics = evaluate(model, instances, iterator, args.cuda_device, args.batch_weight_key) + + logger.info("Finished evaluating.") + logger.info("Metrics:") + for key, metric in metrics.items(): + logger.info("%s: %s", key, metric) + + output_file = args.output_file + if output_file: + with open(output_file, "w") as file: + json.dump(metrics, file, indent=4) + return metrics
+
+
+if __name__ == "__main__":
+    args = argparser.parse_args()
+    evaluate_from_args(args)
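For orientation, a minimal sketch of driving this evaluation script from Python rather than the shell; every path below is a hypothetical placeholder, and the Namespace fields simply mirror the argparse flags defined above:

import argparse

# Hypothetical paths; point these at a real model archive and evaluation file.
args = argparse.Namespace(
    archive_file='models/300/model.tar.gz',
    input_file='data/multiwoz/eval.json',
    output_file='eval_metrics.json',
    weights_file=None,
    cuda_device=-1,
    overrides='',
    batch_weight_key='',
    extend_vocab=False,
    embedding_sources_mapping='')
metrics = evaluate_from_args(args)  # returns the metrics dict; also writes JSON when output_file is set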
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/model.html b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/model.html
new file mode 100644
index 0000000..d8de29f
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/model.html
@@ -0,0 +1,289 @@
+convlab.modules.policy.system.multiwoz.vanilla_mle.model — ConvLab 0.1 documentation

Source code for convlab.modules.policy.system.multiwoz.vanilla_mle.model

+from typing import Dict, Optional, List, Any
+
+from overrides import overrides
+import numpy as np
+import torch
+from torch.nn.modules.linear import Linear
+
+from allennlp.data import Vocabulary
+from allennlp.modules import FeedForward
+from allennlp.models.model import Model
+from allennlp.nn import InitializerApplicator, RegularizerApplicator
+from allennlp.training.metrics import CategoricalAccuracy
+
+
+
+@Model.register("vanilla_mle_policy")
+class VanillaMLE(Model):
+    """
+    The ``VanillaMLE`` model makes predictions based on a softmax over a list of top combinatorial actions.
+    """
+
+    def __init__(self, vocab: Vocabulary,
+                 input_dim: int,
+                 num_classes: int,
+                 label_namespace: str = "labels",
+                 feedforward: Optional[FeedForward] = None,
+                 dropout: Optional[float] = None,
+                 verbose_metrics: bool = False,
+                 initializer: InitializerApplicator = InitializerApplicator(),
+                 regularizer: Optional[RegularizerApplicator] = None) -> None:
+        super().__init__(vocab, regularizer)
+        self.label_namespace = label_namespace
+        self.input_dim = input_dim
+        self.num_classes = num_classes
+        self._verbose_metrics = verbose_metrics
+        if dropout:
+            self.dropout = torch.nn.Dropout(dropout)
+        else:
+            self.dropout = None
+        self._feedforward = feedforward
+
+        if self._feedforward is not None:
+            self.projection_layer = Linear(feedforward.get_output_dim(), self.num_classes)
+        else:
+            self.projection_layer = Linear(self.input_dim, self.num_classes)
+
+        self.metrics = {
+            "accuracy": CategoricalAccuracy(),
+            "accuracy3": CategoricalAccuracy(top_k=3),
+            "accuracy5": CategoricalAccuracy(top_k=5)
+        }
+        self._loss = torch.nn.CrossEntropyLoss()
+
+        initializer(self)
+
+    @overrides
+    def forward(self,  # type: ignore
+                states: torch.FloatTensor,
+                actions: torch.LongTensor = None,
+                metadata: List[Dict[str, Any]] = None,
+                # pylint: disable=unused-argument
+                **kwargs) -> Dict[str, torch.Tensor]:
+        # pylint: disable=arguments-differ
+        """
+        Compute action logits and probabilities for the given encoded states, and
+        the cross-entropy loss (plus accuracy metrics) when gold ``actions`` are given.
+        """
+        if self.dropout:
+            states = self.dropout(states)
+
+        if self._feedforward is not None:
+            states = self._feedforward(states)
+
+        logits = self.projection_layer(states)
+
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+        output = {"logits": logits, "probs": probs}
+
+        if actions is not None:
+            output["loss"] = self._loss(logits, actions)
+            for metric in self.metrics.values():
+                metric(logits, actions)
+
+        return output
+ +
+    @overrides
+    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """
+        Does a simple argmax over the class probabilities.
+        """
+        predictions = output_dict["probs"].detach().cpu().numpy()
+        argmax_indices = np.argmax(predictions, axis=-1)
+        output_dict["actions"] = argmax_indices
+
+        return output_dict
+ +
+    @overrides
+    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
+        return {metric_name: metric.get_metric(reset)
+                for metric_name, metric in self.metrics.items()}
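To make the model's interface concrete, here is a small, hedged sketch of exercising ``VanillaMLE`` in isolation; the input dimension and number of action classes below are invented for illustration (in practice they are derived from the state encoder and the action vocabulary):

import torch
from allennlp.data import Vocabulary

model = VanillaMLE(vocab=Vocabulary(), input_dim=392, num_classes=300)  # made-up sizes
states = torch.randn(8, 392)            # a batch of 8 encoded dialog states
actions = torch.randint(0, 300, (8,))   # gold action indices
output = model(states=states, actions=actions)
print(output['probs'].shape)            # torch.Size([8, 300])
print(float(output['loss']))            # batch cross-entropy loss
decoded = model.decode(output)          # adds argmax indices under 'actions'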
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/policy.html b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/policy.html
new file mode 100644
index 0000000..4683877
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/policy.html
@@ -0,0 +1,262 @@
+convlab.modules.policy.system.multiwoz.vanilla_mle.policy — ConvLab 0.1 documentation

Source code for convlab.modules.policy.system.multiwoz.vanilla_mle.policy

+"""
+"""
+
+import os
+from pprint import pprint
+
+import numpy as np
+from allennlp.common.checks import check_for_gpu
+from allennlp.data import DatasetReader
+from allennlp.models.archival import load_archive
+
+from convlab.lib.file_util import cached_path
+from convlab.modules.dst.multiwoz.dst_util import init_state
+from convlab.modules.policy.system.multiwoz.util import action_decoder
+from convlab.modules.policy.system.policy import SysPolicy
+# imported so the dataset reader and model register themselves with AllenNLP
+from convlab.modules.policy.system.multiwoz.vanilla_mle import dataset_reader, model
+
+DEFAULT_CUDA_DEVICE = -1
+DEFAULT_ARCHIVE_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models/300/model.tar.gz")
+
+
+
+class VanillaMLEPolicy(SysPolicy):
+    """Vanilla MLE trained policy."""
+
+    def __init__(self,
+                 archive_file=DEFAULT_ARCHIVE_FILE,
+                 cuda_device=DEFAULT_CUDA_DEVICE,
+                 model_file=None):
+        """ Constructor for VanillaMLEPolicy class. """
+        SysPolicy.__init__(self)
+
+        check_for_gpu(cuda_device)
+
+        if not os.path.isfile(archive_file):
+            if not model_file:
+                raise Exception("No model for the vanilla MLE policy is specified!")
+            archive_file = cached_path(model_file)
+
+        archive = load_archive(archive_file,
+                               cuda_device=cuda_device)
+        dataset_reader_params = archive.config["dataset_reader"]
+        self.dataset_reader = DatasetReader.from_params(dataset_reader_params)
+        self.action_vocab = self.dataset_reader.action_vocab
+        self.state_encoder = self.dataset_reader.state_encoder
+        self.model = archive.model
+        self.model.eval()
+
+    def predict(self, state):
+        """
+        Predict the system dialog act given the current dialog state.
+        Args:
+            state (dict): Dialog state.
+        Returns:
+            dialacts (dict): The dialog act of the current system turn.
+        """
+        state_vector = self.state_encoder.encode(state)
+
+        instance = self.dataset_reader.text_to_instance(state_vector)
+        outputs = self.model.forward_on_instance(instance)
+        dialacts = action_decoder(state, outputs["actions"], self.action_vocab)
+        # Avoid closing the dialog with a premature 'general-bye': fall back to
+        # the next most probable action.
+        if dialacts == {'general-bye': [['none', 'none']]}:
+            outputs["probs"][outputs["actions"]] = 0
+            outputs["actions"] = np.argmax(outputs["probs"])
+            dialacts = action_decoder(state, outputs["actions"], self.action_vocab)
+
+        if state == init_state():
+            dialacts = {}
+
+        return dialacts
+
+
+if __name__ == "__main__":
+    from convlab.modules.dst.multiwoz.dst_util import init_state
+
+    policy = VanillaMLEPolicy()
+    pprint(policy.predict(init_state()))
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/train.html b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/train.html
new file mode 100644
index 0000000..36ce964
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/policy/system/multiwoz/vanilla_mle/train.html
@@ -0,0 +1,382 @@
+convlab.modules.policy.system.multiwoz.vanilla_mle.train — ConvLab 0.1 documentation

Source code for convlab.modules.policy.system.multiwoz.vanilla_mle.train

+"""
+The ``train`` subcommand can be used to train a model.
+It requires a configuration file and a directory in
+which to write the results.
+"""
+
+import argparse
+import logging
+import os
+
+from allennlp.common import Params
+from allennlp.common.checks import check_for_gpu
+from allennlp.common.util import prepare_environment, prepare_global_logging, cleanup_global_logging, dump_metrics
+from allennlp.models.archival import archive_model, CONFIG_NAME
+from allennlp.models.model import Model, _DEFAULT_WEIGHTS
+from allennlp.training.trainer import Trainer, TrainerPieces
+from allennlp.training.trainer_base import TrainerBase
+from allennlp.training.util import create_serialization_dir, evaluate
+
+# imported so the dataset reader and model register themselves with AllenNLP
+from convlab.modules.policy.system.multiwoz.vanilla_mle import dataset_reader, model
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+argparser = argparse.ArgumentParser(description="Train a model.")
+argparser.add_argument('param_path',
+                        type=str,
+                        help='path to parameter file describing the model to be trained')
+argparser.add_argument('-s', '--serialization-dir',
+                        required=True,
+                        type=str,
+                        help='directory in which to save the model and its logs')
+argparser.add_argument('-r', '--recover',
+                        action='store_true',
+                        default=False,
+                        help='recover training from the state in serialization_dir')
+argparser.add_argument('-f', '--force',
+                        action='store_true',
+                        required=False,
+                        help='overwrite the output directory if it exists')
+argparser.add_argument('-o', '--overrides',
+                        type=str,
+                        default="",
+                        help='a JSON structure used to override the experiment configuration')
+argparser.add_argument('--file-friendly-logging',
+                        action='store_true',
+                        default=False,
+                        help='outputs tqdm status on separate lines and slows tqdm refresh rate')
+
+
+
+
+def train_model_from_args(args: argparse.Namespace):
+    """
+    Just converts from an ``argparse.Namespace`` object to string paths.
+    """
+    train_model_from_file(args.param_path,
+                          args.serialization_dir,
+                          args.overrides,
+                          args.file_friendly_logging,
+                          args.recover,
+                          args.force)
+ + +
+def train_model_from_file(parameter_filename: str,
+                          serialization_dir: str,
+                          overrides: str = "",
+                          file_friendly_logging: bool = False,
+                          recover: bool = False,
+                          force: bool = False) -> Model:
+    """
+    A wrapper around :func:`train_model` which loads the params from a file.
+
+    Parameters
+    ----------
+    parameter_filename : ``str``
+        A json parameter file specifying an AllenNLP experiment.
+    serialization_dir : ``str``
+        The directory in which to save results and logs. We just pass this along to
+        :func:`train_model`.
+    overrides : ``str``
+        A JSON string that we will use to override values in the input parameter file.
+    file_friendly_logging : ``bool``, optional (default=False)
+        If ``True``, we make our output more friendly to saved model files. We just pass this
+        along to :func:`train_model`.
+    recover : ``bool``, optional (default=False)
+        If ``True``, we will try to recover a training run from an existing serialization
+        directory. This is only intended for use when something actually crashed during the middle
+        of a run. For continuing training a model on new data, see the ``fine-tune`` command.
+    force : ``bool``, optional (default=False)
+        If ``True``, we will overwrite the serialization directory if it already exists.
+    """
+    # Load the experiment config from a file and pass it to ``train_model``.
+    params = Params.from_file(parameter_filename, overrides)
+    return train_model(params, serialization_dir, file_friendly_logging, recover, force)
+ + +
+def train_model(params: Params,
+                serialization_dir: str,
+                file_friendly_logging: bool = False,
+                recover: bool = False,
+                force: bool = False) -> Model:
+    """
+    Trains the model specified in the given :class:`Params` object, using the data and training
+    parameters also specified in that object, and saves the results in ``serialization_dir``.
+
+    Parameters
+    ----------
+    params : ``Params``
+        A parameter object specifying an AllenNLP Experiment.
+    serialization_dir : ``str``
+        The directory in which to save results and logs.
+    file_friendly_logging : ``bool``, optional (default=False)
+        If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
+        down tqdm's output to only once every 10 seconds.
+    recover : ``bool``, optional (default=False)
+        If ``True``, we will try to recover a training run from an existing serialization
+        directory. This is only intended for use when something actually crashed during the middle
+        of a run. For continuing training a model on new data, see the ``fine-tune`` command.
+    force : ``bool``, optional (default=False)
+        If ``True``, we will overwrite the serialization directory if it already exists.
+
+    Returns
+    -------
+    best_model : ``Model``
+        The model with the best epoch weights.
+    """
+    prepare_environment(params)
+    create_serialization_dir(params, serialization_dir, recover, force)
+    stdout_handler = prepare_global_logging(serialization_dir, file_friendly_logging)
+
+    cuda_device = params.params.get('trainer').get('cuda_device', -1)
+    check_for_gpu(cuda_device)
+
+    params.to_file(os.path.join(serialization_dir, CONFIG_NAME))
+
+    evaluate_on_test = params.pop_bool("evaluate_on_test", False)
+
+    trainer_type = params.get("trainer", {}).get("type", "default")
+
+    if trainer_type == "default":
+        # Special logic to instantiate backward-compatible trainer.
+        pieces = TrainerPieces.from_params(params, serialization_dir, recover)  # pylint: disable=no-member
+        trainer = Trainer.from_params(
+            model=pieces.model,
+            serialization_dir=serialization_dir,
+            iterator=pieces.iterator,
+            train_data=pieces.train_dataset,
+            validation_data=pieces.validation_dataset,
+            params=pieces.params,
+            validation_iterator=pieces.validation_iterator)
+        evaluation_iterator = pieces.validation_iterator or pieces.iterator
+        evaluation_dataset = pieces.test_dataset
+
+    else:
+        trainer = TrainerBase.from_params(params, serialization_dir, recover)
+        # TODO(joelgrus): handle evaluation in the general case
+        evaluation_iterator = evaluation_dataset = None
+
+    params.assert_empty('base train command')
+
+    try:
+        metrics = trainer.train()
+    except KeyboardInterrupt:
+        # if we have completed an epoch, try to create a model archive.
+        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
+            logging.info("Training interrupted by the user. Attempting to create "
+                         "a model archive using the current best epoch weights.")
+            archive_model(serialization_dir, files_to_archive=params.files_to_archive)
+        raise
+
+    # Evaluate
+    if evaluation_dataset and evaluate_on_test:
+        logger.info("The model will be evaluated using the best epoch weights.")
+        test_metrics = evaluate(trainer.model, evaluation_dataset, evaluation_iterator,
+                                cuda_device=trainer._cuda_devices[0],  # pylint: disable=protected-access
+                                # TODO(brendanr): Pass in an arg following Joel's trainer refactor.
+                                batch_weight_key="")
+
+        for key, value in test_metrics.items():
+            metrics["test_" + key] = value
+
+    elif evaluation_dataset:
+        logger.info("To evaluate on the test set after training, pass the "
+                    "'evaluate_on_test' flag, or use the 'allennlp evaluate' command.")
+
+    cleanup_global_logging(stdout_handler)
+
+    # Now tar up results
+    archive_model(serialization_dir, files_to_archive=params.files_to_archive)
+    dump_metrics(os.path.join(serialization_dir, "metrics.json"), metrics, log=True)
+
+    # We count on the trainer to have the model with best weights
+    return trainer.model
+
+if __name__ == "__main__":
+    args = argparser.parse_args()
+    train_model_from_args(args)
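A minimal sketch of invoking training programmatically; the config path and output directory are hypothetical, and the experiment file is expected to supply the keys this function consumes ('trainer', plus the dataset and iterator configuration popped by TrainerPieces):

from allennlp.common import Params

params = Params.from_file('configs/vanilla_mle.jsonnet')  # hypothetical experiment config
best_model = train_model(params, serialization_dir='output/vanilla_mle', force=True)

The rough command-line equivalent, given the argparser above, is `python train.py configs/vanilla_mle.jsonnet -s output/vanilla_mle -f`.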
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/policy/system/policy.html b/docs/build/html/_modules/convlab/modules/policy/system/policy.html
new file mode 100644
index 0000000..3639400
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/policy/system/policy.html
@@ -0,0 +1,212 @@
+convlab.modules.policy.system.policy — ConvLab 0.1 documentation

Source code for convlab.modules.policy.system.policy

+"""
+The policy base class for system bot.
+"""
+
+
+
+class SysPolicy:
+    """Base class for system policy model."""
+
+    def __init__(self):
+        """ Constructor for SysPolicy class. """
+        pass
+
+    def predict(self, state):
+        """
+        Predict the system action (dialog act) given state.
+        Args:
+            state (dict): Dialog state. For more details about each field of the dialog state,
+                please refer to the init_state method in convlab/modules/dst/multiwoz/dst_util.py.
+        Returns:
+            action (dict): The dialog act of the current turn system response, which is then
+                passed to the NLG module to generate an NL utterance.
+        """
+        pass
+ +
+    def init_session(self):
+        """Init the SysPolicy module to start a new session."""
+        pass
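To make the contract concrete, a minimal sketch of a custom policy implementing this interface; the fixed greeting behaviour is invented purely for illustration:

class GreetPolicy(SysPolicy):
    """Toy policy that always greets, regardless of the dialog state."""

    def predict(self, state):
        # A dialog act dict in the same format used throughout this section.
        return {'general-greet': [['none', 'none']]}

    def init_session(self):
        pass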
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/policy/user/multiwoz/policy_agenda_multiwoz.html b/docs/build/html/_modules/convlab/modules/policy/user/multiwoz/policy_agenda_multiwoz.html
new file mode 100644
index 0000000..834b7f1
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/policy/user/multiwoz/policy_agenda_multiwoz.html
@@ -0,0 +1,1031 @@
+convlab.modules.policy.user.multiwoz.policy_agenda_multiwoz — ConvLab 0.1 documentation

Source code for convlab.modules.policy.user.multiwoz.policy_agenda_multiwoz

+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+"""
+
+__time__ = '2019/1/31 10:24'
+
+import copy
+import json
+import os
+import random
+import re
+
+from convlab.lib import logger
+from convlab.modules.policy.user.policy import UserPolicy
+from convlab.modules.usr.multiwoz.goal_generator import GoalGenerator
+from convlab.modules.util.multiwoz.multiwoz_slot_trans import REF_USR_DA, REF_SYS_DA
+
+logger = logger.get_logger(__name__)
+
+DEF_VAL_UNK = '?'  # Unknown
+DEF_VAL_DNC = 'don\'t care'  # Do not care
+DEF_VAL_NUL = 'none'  # for none
+DEF_VAL_BOOKED = 'yes'  # for booked
+DEF_VAL_NOBOOK = 'no'  # for not booked
+NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK]
+
+# import reflect table
+REF_USR_DA_M = copy.deepcopy(REF_USR_DA)
+REF_SYS_DA_M = {}
+for dom, ref_slots in REF_SYS_DA.items():
+    dom = dom.lower()
+    REF_SYS_DA_M[dom] = {}
+    for slot_a, slot_b in ref_slots.items():
+        REF_SYS_DA_M[dom][slot_a.lower()] = slot_b
+    REF_SYS_DA_M[dom]['none'] = None
+
+# def book slot
+BOOK_SLOT = ['people', 'day', 'stay', 'time']
+
+
+class UserPolicyAgendaMultiWoz(UserPolicy):
+    """ The rule-based user policy model by agenda. Derived from the UserPolicy class """
+
+    # load the standard value set
+    stand_value_dict = json.load(open(os.path.join(os.path.dirname(
+        os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))),
+        'data/value_set.json')))
+
+    def __init__(self, max_goal_num=100, seed=2019):
+        """
+        Constructor for User_Policy_Agenda class.
+        """
+        self.max_turn = 40
+        self.max_initiative = 4
+
+        self.goal_generator = GoalGenerator(corpus_path='data/multiwoz/annotated_user_da_with_span_full.json')
+
+        self.__turn = 0
+        self.goal = None
+        self.agenda = None
+
+        random.seed(seed)
+        self.goal_seeds = [random.randint(1, 1e7) for i in range(max_goal_num)]
+
+        # UserPolicy.__init__(self, act_types, slots, slot_dict)
+        UserPolicy.__init__(self)
+
+    def init_session(self):
+        """ Build new Goal and Agenda for next session """
+        self.__turn = 0
+        if len(self.goal_seeds) > 1:
+            self.goal = Goal(self.goal_generator, self.goal_seeds[0])
+            self.goal_seeds = self.goal_seeds[1:]
+        else:
+            self.goal = Goal(self.goal_generator)
+        self.domain_goals = self.goal.domain_goals
+        self.agenda = Agenda(self.goal)
+ +
+    def predict(self, state, sys_action):
+        """
+        Predict a user act based on state and the previous system action.
+        Args:
+            state (tuple): Dialog state.
+            sys_action (tuple): Previous system action.
+        Returns:
+            action (tuple): User act.
+            session_over (boolean): True to terminate session, otherwise session continues.
+            reward (float): Reward given by user.
+        """
+        self.__turn += 2
+
+        # At the beginning of a dialog when there is no NLU.
+        if sys_action == "null":
+            sys_action = {}
+
+        if self.__turn > self.max_turn:
+            self.agenda.close_session()
+        else:
+            sys_action = self._transform_sysact_in(sys_action)
+            self.agenda.update(sys_action, self.goal)
+            if self.goal.task_complete():
+                self.agenda.close_session()
+
+        # A -> A' + user_action
+        # action = self.agenda.get_action(random.randint(1, self.max_initiative))
+        action = self.agenda.get_action(self.max_initiative)
+
+        # Is there any action to say?
+        session_over = self.agenda.is_empty()
+
+        # reward
+        reward = self._reward()
+
+        # transform to DA
+        action = self._transform_usract_out(action)
+
+        return action, session_over, reward
+
+    def _reward(self):
+        """
+        Calculate reward based on task completion
+        Returns:
+            reward (float): Reward given by user.
+        """
+        if self.goal.task_complete():
+            reward = 2.0 * self.max_turn
+        elif self.agenda.is_empty():
+            reward = -1.0 * self.max_turn
+        else:
+            reward = -1.0
+        return reward
+
+    @classmethod
+    def _transform_usract_out(cls, action):
+        new_action = {}
+        for act in action.keys():
+            if '-' in act:
+                if 'general' not in act:
+                    (dom, intent) = act.split('-')
+                    new_act = dom.capitalize() + '-' + intent.capitalize()
+                    new_action[new_act] = []
+                    for pairs in action[act]:
+                        slot = REF_USR_DA_M[dom.capitalize()].get(pairs[0], None)
+                        if slot is not None:
+                            new_action[new_act].append([slot, pairs[1]])
+                else:
+                    new_action[act] = action[act]
+            else:
+                pass
+        return new_action
+
+    @classmethod
+    def _transform_sysact_in(cls, action):
+        new_action = {}
+        if not isinstance(action, dict):
+            logger.warning(f'illegal da: {action}')
+            return new_action
+
+        for act in action.keys():
+            if not isinstance(act, str) or '-' not in act:
+                logger.warning(f'illegal act: {act}')
+                continue
+
+            if 'general' not in act:
+                (dom, intent) = act.lower().split('-')
+                if dom in REF_SYS_DA_M.keys():
+                    new_list = []
+                    for pairs in action[act]:
+                        if (not isinstance(pairs, list) and not isinstance(pairs, tuple)) or \
+                                (len(pairs) < 2) or \
+                                (not isinstance(pairs[0], str) or (not isinstance(pairs[1], str) and not isinstance(pairs[1], int))):
+                            logger.warning(f'illegal pairs: {pairs}')
+                            continue
+
+                        if REF_SYS_DA_M[dom].get(pairs[0].lower(), None) is not None:
+                            new_list.append([REF_SYS_DA_M[dom][pairs[0].lower()],
+                                             cls._normalize_value(dom, intent, REF_SYS_DA_M[dom][pairs[0].lower()], pairs[1])])
+
+                    if len(new_list) > 0:
+                        new_action[act.lower()] = new_list
+            else:
+                new_action[act.lower()] = action[act]
+
+        return new_action
+
+    @classmethod
+    def _normalize_value(cls, domain, intent, slot, value):
+        if intent == 'request':
+            return DEF_VAL_UNK
+
+        if domain not in cls.stand_value_dict.keys():
+            return value
+
+        if slot not in cls.stand_value_dict[domain]:
+            return value
+
+        value_list = cls.stand_value_dict[domain][slot]
+        low_value_list = [item.lower() for item in value_list]
+        value_list = list(set(value_list).union(set(low_value_list)))
+        if value not in value_list:
+            normalized_v = simple_fuzzy_match(value_list, value)
+            if normalized_v is not None:
+                return normalized_v
+            # try some transformations
+            cand_values = transform_value(value)
+            for cv in cand_values:
+                _nv = simple_fuzzy_match(value_list, cv)
+                if _nv is not None:
+                    return _nv
+            if check_if_time(value):
+                return value
+
+            logger.debug('Value not found in standard value set: [%s] (slot: %s domain: %s)' % (value, slot, domain))
+        return value
+ +
+def transform_value(value):
+    cand_list = []
+    # a 's -> a's
+    if " 's" in value:
+        cand_list.append(value.replace(" 's", "'s"))
+    # a - b -> a-b
+    if " - " in value:
+        cand_list.append(value.replace(" - ", "-"))
+    return cand_list
+ +
+def simple_fuzzy_match(value_list, value):
+    # check contain relation
+    v0 = ' '.join(value.split())
+    v0N = ''.join(value.split())
+    for val in value_list:
+        v1 = ' '.join(val.split())
+        if v0 in v1 or v1 in v0 or v0N in v1 or v1 in v0N:
+            return v1
+    value = value.lower()
+    v0 = ' '.join(value.split())
+    v0N = ''.join(value.split())
+    for val in value_list:
+        v1 = ' '.join(val.split())
+        if v0 in v1 or v1 in v0 or v0N in v1 or v1 in v0N:
+            return v1
+    return None
+ +
+def check_if_time(value):
+    value = value.strip()
+    match = re.search(r"(\d{1,2}:\d{1,2})", value)
+    if match is None:
+        return False
+    groups = match.groups()
+    if len(groups) <= 0:
+        return False
+    return True
+ + +
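A small illustration, with invented values, of how these normalization helpers behave:

print(transform_value("the king 's arms"))                       # ["the king's arms"]
print(simple_fuzzy_match(['the golden curry'], 'Golden Curry'))  # 'the golden curry'
print(check_if_time('arrives at 09:45'))                         # True
print(check_if_time('cheap'))                                    # False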
+class Goal(object):
+    """ User Goal Model Class. """
+
+    def __init__(self, goal_generator: GoalGenerator, seed=None):
+        """
+        Create a new Goal at random.
+        Args:
+            goal_generator (GoalGenerator): Goal Generator.
+        """
+        self.domain_goals = goal_generator.get_user_goal(seed)
+
+        self.domains = list(self.domain_goals['domain_ordering'])
+        del self.domain_goals['domain_ordering']
+
+        for domain in self.domains:
+            if 'reqt' in self.domain_goals[domain].keys():
+                self.domain_goals[domain]['reqt'] = {slot: DEF_VAL_UNK for slot in self.domain_goals[domain]['reqt']}
+
+            if 'book' in self.domain_goals[domain].keys():
+                self.domain_goals[domain]['booked'] = DEF_VAL_UNK
+
+    def task_complete(self):
+        """
+        Check whether all requests have been met.
+        Returns:
+            (boolean): True if accomplished.
+        """
+        for domain in self.domains:
+            if 'reqt' in self.domain_goals[domain]:
+                reqt_vals = self.domain_goals[domain]['reqt'].values()
+                for val in reqt_vals:
+                    if val in NOT_SURE_VALS:
+                        return False
+
+            if 'booked' in self.domain_goals[domain]:
+                if self.domain_goals[domain]['booked'] in NOT_SURE_VALS:
+                    return False
+        return True
+ +
+    def next_domain_incomplete(self):
+        # request
+        for domain in self.domains:
+            # reqt
+            if 'reqt' in self.domain_goals[domain]:
+                requests = self.domain_goals[domain]['reqt']
+                unknow_reqts = [key for (key, val) in requests.items() if val in NOT_SURE_VALS]
+                if len(unknow_reqts) > 0:
+                    return domain, 'reqt', ['name'] if 'name' in unknow_reqts else unknow_reqts
+
+            # book
+            if 'booked' in self.domain_goals[domain]:
+                if self.domain_goals[domain]['booked'] in NOT_SURE_VALS:
+                    return domain, 'book', \
+                        self.domain_goals[domain]['fail_book'] if 'fail_book' in self.domain_goals[domain].keys() else self.domain_goals[domain]['book']
+
+        return None, None, None
+
+    def __str__(self):
+        return '-----Goal-----\n' + \
+               json.dumps(self.domain_goals, indent=4) + \
+               '\n-----Goal-----'
+ + +
+class Agenda(object):
+    def __init__(self, goal: Goal):
+        """
+        Build a new agenda from goal
+        Args:
+            goal (Goal): User goal.
+        """
+
+        def random_sample(data, minimum=0, maximum=1000):
+            return random.sample(data, random.randint(min(len(data), minimum), min(len(data), maximum)))
+
+        self.CLOSE_ACT = 'general-bye'
+        self.HELLO_ACT = 'general-greet'
+        self.__cur_push_num = 0
+
+        self.__stack = []
+
+        # there is a 'bye' action at the bottom of the stack
+        self.__push(self.CLOSE_ACT)
+
+        for idx in range(len(goal.domains) - 1, -1, -1):
+            domain = goal.domains[idx]
+
+            # inform
+            if 'fail_info' in goal.domain_goals[domain]:
+                for slot in random_sample(goal.domain_goals[domain]['fail_info'].keys(),
+                                          len(goal.domain_goals[domain]['fail_info'])):
+                    self.__push(domain + '-inform', slot, goal.domain_goals[domain]['fail_info'][slot])
+            elif 'info' in goal.domain_goals[domain]:
+                for slot in random_sample(goal.domain_goals[domain]['info'].keys(),
+                                          len(goal.domain_goals[domain]['info'])):
+                    self.__push(domain + '-inform', slot, goal.domain_goals[domain]['info'][slot])
+
+        self.cur_domain = None
+
+    def update(self, sys_action, goal: Goal):
+        """
+        Update the agenda given the current system action and goal. { A' + G" + sys_action -> A" }
+        Args:
+            sys_action (tuple): Previous system action.
+            goal (Goal): User goal.
+        """
+        self.__cur_push_num = 0
+        self._update_current_domain(sys_action, goal)
+
+        for diaact in sys_action.keys():
+            slot_vals = sys_action[diaact]
+            if 'nooffer' in diaact:
+                if self.update_domain(diaact, slot_vals, goal):
+                    return
+            elif 'nobook' in diaact:
+                if self.update_booking(diaact, slot_vals, goal):
+                    return
+
+        for diaact in sys_action.keys():
+            if 'nooffer' in diaact or 'nobook' in diaact:
+                continue
+
+            slot_vals = sys_action[diaact]
+            if 'booking' in diaact:
+                if self.update_booking(diaact, slot_vals, goal):
+                    return
+            elif 'general' in diaact:
+                if self.update_general(diaact, slot_vals, goal):
+                    return
+            else:
+                if self.update_domain(diaact, slot_vals, goal):
+                    return
+
+        unk_dom, unk_type, data = goal.next_domain_incomplete()
+        if unk_dom is not None:
+            if unk_type == 'reqt' and not self._check_reqt_info(unk_dom) and not self._check_reqt(unk_dom):
+                for slot in data:
+                    self._push_item(unk_dom + '-request', slot, DEF_VAL_UNK)
+            elif unk_type == 'book' and not self._check_reqt_info(unk_dom) and not self._check_book_info(unk_dom):
+                for (slot, val) in data.items():
+                    self._push_item(unk_dom + '-inform', slot, val)
+ +
+    def update_booking(self, diaact, slot_vals, goal: Goal):
+        """
+        Handle Booking-XXX dialog acts.
+        :param diaact: dialog act
+        :param slot_vals: slot value pairs
+        :param goal: Goal
+        :return: True if the user wants to close the session, False if the session continues
+        """
+        _, intent = diaact.split('-')
+        domain = self.cur_domain
+
+        if domain not in goal.domains:
+            return False
+
+        g_reqt = goal.domain_goals[domain].get('reqt', dict({}))
+        g_info = goal.domain_goals[domain].get('info', dict({}))
+        g_fail_info = goal.domain_goals[domain].get('fail_info', dict({}))
+        g_book = goal.domain_goals[domain].get('book', dict({}))
+        g_fail_book = goal.domain_goals[domain].get('fail_book', dict({}))
+
+        if intent in ['book', 'inform']:
+            info_right = True
+            for [slot, value] in slot_vals:
+                if slot == 'time':
+                    if domain in ['train', 'restaurant']:
+                        slot = 'duration' if domain == 'train' else 'time'
+                    else:
+                        logger.warning(f'illegal booking slot: {slot}, domain: {domain}')
+                        continue
+
+                if slot in g_reqt:
+                    if not self._check_reqt_info(domain):
+                        self._remove_item(domain + '-request', slot)
+                        if value in NOT_SURE_VALS:
+                            g_reqt[slot] = '\"' + value + '\"'
+                        else:
+                            g_reqt[slot] = value
+
+                elif slot in g_fail_info and value != g_fail_info[slot]:
+                    self._push_item(domain + '-inform', slot, g_fail_info[slot])
+                    info_right = False
+                elif len(g_fail_info) <= 0 and slot in g_info and value != g_info[slot]:
+                    self._push_item(domain + '-inform', slot, g_info[slot])
+                    info_right = False
+
+                elif slot in g_fail_book and value != g_fail_book[slot]:
+                    self._push_item(domain + '-inform', slot, g_fail_book[slot])
+                    info_right = False
+                elif len(g_fail_book) <= 0 and slot in g_book and value != g_book[slot]:
+                    self._push_item(domain + '-inform', slot, g_book[slot])
+                    info_right = False
+
+                else:
+                    pass
+
+            if intent == 'book' and info_right:
+                # booked ok
+                if 'booked' in goal.domain_goals[domain]:
+                    goal.domain_goals[domain]['booked'] = DEF_VAL_BOOKED
+                self._push_item('general-thank')
+
+        elif intent in ['nobook']:
+            if len(g_fail_book) > 0:
+                # Discard fail_book data and update the book data to the stack
+                for slot in g_book.keys():
+                    if (slot not in g_fail_book) or (slot in g_fail_book and g_fail_book[slot] != g_book[slot]):
+                        self._push_item(domain + '-inform', slot, g_book[slot])
+
+                # change fail_book name
+                goal.domain_goals[domain]['fail_book_fail'] = goal.domain_goals[domain].pop('fail_book')
+            elif 'booked' in goal.domain_goals[domain].keys():
+                self.close_session()
+                return True
+
+        elif intent in ['request']:
+            for [slot, _] in slot_vals:
+                if slot == 'time':
+                    if domain in ['train', 'restaurant']:
+                        slot = 'duration' if domain == 'train' else 'time'
+                    else:
+                        logger.warning('illegal booking slot: %s, slot: %s domain' % (slot, domain))
+                        continue
+
+                if slot in g_reqt:
+                    pass
+                elif slot in g_fail_info:
+                    self._push_item(domain + '-inform', slot, g_fail_info[slot])
+                elif len(g_fail_info) <= 0 and slot in g_info:
+                    self._push_item(domain + '-inform', slot, g_info[slot])
+
+                elif slot in g_fail_book:
+                    self._push_item(domain + '-inform', slot, g_fail_book[slot])
+                elif len(g_fail_book) <= 0 and slot in g_book:
+                    self._push_item(domain + '-inform', slot, g_book[slot])
+
+                else:
+                    if domain == 'taxi' and (slot == 'destination' or slot == 'departure'):
+                        places = [dom for dom in goal.domains[: goal.domains.index('taxi')] if
+                                  'address' in goal.domain_goals[dom]['reqt']]
+
+                        if len(places) >= 1 and slot == 'destination' and \
+                                goal.domain_goals[places[-1]]['reqt']['address'] not in NOT_SURE_VALS:
+                            self._push_item(domain + '-inform', slot, goal.domain_goals[places[-1]]['reqt']['address'])
+
+                        elif len(places) >= 2 and slot == 'departure' and \
+                                goal.domain_goals[places[-2]]['reqt']['address'] not in NOT_SURE_VALS:
+                            self._push_item(domain + '-inform', slot, goal.domain_goals[places[-2]]['reqt']['address'])
+
+                        elif random.random() < 0.5:
+                            self._push_item(domain + '-inform', slot, DEF_VAL_DNC)
+
+                    elif random.random() < 0.5:
+                        self._push_item(domain + '-inform', slot, DEF_VAL_DNC)
+
+        return False
+ +
+    def update_domain(self, diaact, slot_vals, goal: Goal):
+        """
+        Handle Domain-XXX dialog acts.
+        :param diaact: dialog act
+        :param slot_vals: slot value pairs
+        :param goal: Goal
+        :return: True if the user wants to close the session, False if the session continues
+        """
+        domain, intent = diaact.split('-')
+
+        if domain not in goal.domains:
+            return False
+
+        g_reqt = goal.domain_goals[domain].get('reqt', dict({}))
+        g_info = goal.domain_goals[domain].get('info', dict({}))
+        g_fail_info = goal.domain_goals[domain].get('fail_info', dict({}))
+        g_book = goal.domain_goals[domain].get('book', dict({}))
+        g_fail_book = goal.domain_goals[domain].get('fail_book', dict({}))
+
+        if intent in ['inform', 'recommend', 'offerbook', 'offerbooked']:
+            info_right = True
+            for [slot, value] in slot_vals:
+                if slot in g_reqt:
+                    if not self._check_reqt_info(domain):
+                        self._remove_item(domain + '-request', slot)
+                        if value in NOT_SURE_VALS:
+                            g_reqt[slot] = '\"' + value + '\"'
+                        else:
+                            g_reqt[slot] = value
+
+                elif slot in g_fail_info and value != g_fail_info[slot]:
+                    self._push_item(domain + '-inform', slot, g_fail_info[slot])
+                    info_right = False
+                elif len(g_fail_info) <= 0 and slot in g_info and value != g_info[slot]:
+                    self._push_item(domain + '-inform', slot, g_info[slot])
+                    info_right = False
+
+                elif slot in g_fail_book and value != g_fail_book[slot]:
+                    self._push_item(domain + '-inform', slot, g_fail_book[slot])
+                    info_right = False
+                elif len(g_fail_book) <= 0 and slot in g_book and value != g_book[slot]:
+                    self._push_item(domain + '-inform', slot, g_book[slot])
+                    info_right = False
+
+                else:
+                    pass
+
+            if intent == 'offerbooked' and info_right:
+                # booked ok
+                if 'booked' in goal.domain_goals[domain]:
+                    goal.domain_goals[domain]['booked'] = DEF_VAL_BOOKED
+                self._push_item('general-thank')
+
+        elif intent in ['request']:
+            for [slot, _] in slot_vals:
+                if slot in g_reqt:
+                    pass
+                elif slot in g_fail_info:
+                    self._push_item(domain + '-inform', slot, g_fail_info[slot])
+                elif len(g_fail_info) <= 0 and slot in g_info:
+                    self._push_item(domain + '-inform', slot, g_info[slot])
+
+                elif slot in g_fail_book:
+                    self._push_item(domain + '-inform', slot, g_fail_book[slot])
+                elif len(g_fail_book) <= 0 and slot in g_book:
+                    self._push_item(domain + '-inform', slot, g_book[slot])
+
+                else:
+                    if domain == 'taxi' and (slot == 'destination' or slot == 'departure'):
+                        places = [dom for dom in goal.domains[: goal.domains.index('taxi')] if
+                                  'address' in goal.domain_goals[dom]['reqt']]
+
+                        if len(places) >= 1 and slot == 'destination' and \
+                                goal.domain_goals[places[-1]]['reqt']['address'] not in NOT_SURE_VALS:
+                            self._push_item(domain + '-inform', slot, goal.domain_goals[places[-1]]['reqt']['address'])
+
+                        elif len(places) >= 2 and slot == 'departure' and \
+                                goal.domain_goals[places[-2]]['reqt']['address'] not in NOT_SURE_VALS:
+                            self._push_item(domain + '-inform', slot, goal.domain_goals[places[-2]]['reqt']['address'])
+
+                        elif random.random() < 0.5:
+                            self._push_item(domain + '-inform', slot, DEF_VAL_DNC)
+
+                    elif random.random() < 0.5:
+                        self._push_item(domain + '-inform', slot, DEF_VAL_DNC)
+
+        elif intent in ['nooffer']:
+            if len(g_fail_info) > 0:
+                # update info data to the stack
+                for slot in g_info.keys():
+                    if (slot not in g_fail_info) or (slot in g_fail_info and g_fail_info[slot] != g_info[slot]):
+                        self._push_item(domain + '-inform', slot, g_info[slot])
+
+                # change fail_info name
+                goal.domain_goals[domain]['fail_info_fail'] = goal.domain_goals[domain].pop('fail_info')
+            elif len(g_reqt.keys()) > 0:
+                self.close_session()
+                return True
+
+        elif intent in ['select']:
+            # delete Choice
+            slot_vals = [[slot, val] for [slot, val] in slot_vals if slot != 'choice']
+
+            if len(slot_vals) > 0:
+                slot = slot_vals[0][0]
+
+                if slot in g_fail_info:
+                    self._push_item(domain + '-inform', slot, g_fail_info[slot])
+                elif len(g_fail_info) <= 0 and slot in g_info:
+                    self._push_item(domain + '-inform', slot, g_info[slot])
+
+                elif slot in g_fail_book:
+                    self._push_item(domain + '-inform', slot, g_fail_book[slot])
+                elif len(g_fail_book) <= 0 and slot in g_book:
+                    self._push_item(domain + '-inform', slot, g_book[slot])
+
+                else:
+                    if not self._check_reqt_info(domain):
+                        [slot, value] = random.choice(slot_vals)
+                        self._push_item(domain + '-inform', slot, value)
+
+                        if slot in g_reqt:
+                            self._remove_item(domain + '-request', slot)
+                            g_reqt[slot] = value
+
+        return False
+ +
+    def update_general(self, diaact, slot_vals, goal: Goal):
+        domain, intent = diaact.split('-')
+
+        if intent == 'bye':
+            # self.close_session()
+            # return True
+            pass
+        elif intent == 'greet':
+            pass
+        elif intent == 'reqmore':
+            pass
+        elif intent == 'welcome':
+            pass
+
+        return False
+ +
+    def close_session(self):
+        """ Clear up all actions """
+        self.__stack = []
+        self.__push(self.CLOSE_ACT)
+ +
+    def get_action(self, initiative=1):
+        """
+        Get multiple acts based on initiative.
+        Args:
+            initiative (int): Number of slots; applies to 'inform' only.
+        Returns:
+            action (dict): User dialog acts.
+        """
+        diaacts, slots, values = self.__pop(initiative)
+        action = {}
+        for (diaact, slot, value) in zip(diaacts, slots, values):
+            if diaact not in action.keys():
+                action[diaact] = []
+            action[diaact].append([slot, value])
+
+        return action
+ +
+    def is_empty(self):
+        """
+        Check whether the agenda is already empty.
+        Returns:
+            (boolean): True if empty, False otherwise.
+        """
+        return len(self.__stack) <= 0
+
+    def _update_current_domain(self, sys_action, goal: Goal):
+        for diaact in sys_action.keys():
+            domain, _ = diaact.split('-')
+            if domain in goal.domains:
+                self.cur_domain = domain
+
+    def _remove_item(self, diaact, slot=DEF_VAL_UNK):
+        for idx in range(len(self.__stack)):
+            if 'general' in diaact:
+                if self.__stack[idx]['diaact'] == diaact:
+                    self.__stack.remove(self.__stack[idx])
+                    break
+            else:
+                if self.__stack[idx]['diaact'] == diaact and self.__stack[idx]['slot'] == slot:
+                    self.__stack.remove(self.__stack[idx])
+                    break
+
+    def _push_item(self, diaact, slot=DEF_VAL_NUL, value=DEF_VAL_NUL):
+        self._remove_item(diaact, slot)
+        self.__push(diaact, slot, value)
+        self.__cur_push_num += 1
+
+    def _check_item(self, diaact, slot=None):
+        for idx in range(len(self.__stack)):
+            if slot is None:
+                if self.__stack[idx]['diaact'] == diaact:
+                    return True
+            else:
+                if self.__stack[idx]['diaact'] == diaact and self.__stack[idx]['slot'] == slot:
+                    return True
+        return False
+
+    def _check_reqt(self, domain):
+        for idx in range(len(self.__stack)):
+            if self.__stack[idx]['diaact'] == domain + '-request':
+                return True
+        return False
+
+    def _check_reqt_info(self, domain):
+        for idx in range(len(self.__stack)):
+            if self.__stack[idx]['diaact'] == domain + '-inform' and self.__stack[idx]['slot'] not in BOOK_SLOT:
+                return True
+        return False
+
+    def _check_book_info(self, domain):
+        for idx in range(len(self.__stack)):
+            if self.__stack[idx]['diaact'] == domain + '-inform' and self.__stack[idx]['slot'] in BOOK_SLOT:
+                return True
+        return False
+
+    def __check_next_diaact_slot(self):
+        if len(self.__stack) > 0:
+            return self.__stack[-1]['diaact'], self.__stack[-1]['slot']
+        return None, None
+
+    def __check_next_diaact(self):
+        if len(self.__stack) > 0:
+            return self.__stack[-1]['diaact']
+        return None
+
+    def __push(self, diaact, slot=DEF_VAL_NUL, value=DEF_VAL_NUL):
+        self.__stack.append({'diaact': diaact, 'slot': slot, 'value': value})
+
+    def __pop(self, initiative=1):
+        diaacts = []
+        slots = []
+        values = []
+
+        p_diaact, p_slot = self.__check_next_diaact_slot()
+        if p_diaact.split('-')[1] == 'inform' and p_slot in BOOK_SLOT:
+            for _ in range(10 if self.__cur_push_num == 0 else self.__cur_push_num):
+                try:
+                    item = self.__stack.pop(-1)
+                    diaacts.append(item['diaact'])
+                    slots.append(item['slot'])
+                    values.append(item['value'])
+
+                    cur_diaact = item['diaact']
+
+                    next_diaact, next_slot = self.__check_next_diaact_slot()
+                    if next_diaact is None or \
+                            next_diaact != cur_diaact or \
+                            next_diaact.split('-')[1] != 'inform' or next_slot not in BOOK_SLOT:
+                        break
+                except:
+                    break
+        else:
+            for _ in range(initiative if self.__cur_push_num == 0 else self.__cur_push_num):
+                try:
+                    item = self.__stack.pop(-1)
+                    diaacts.append(item['diaact'])
+                    slots.append(item['slot'])
+                    values.append(item['value'])
+
+                    cur_diaact = item['diaact']
+
+                    next_diaact = self.__check_next_diaact()
+                    if next_diaact is None or \
+                            next_diaact != cur_diaact or \
+                            (cur_diaact.split('-')[1] == 'request' and item['slot'] == 'name'):
+                        break
+                except:
+                    break
+
+        return diaacts, slots, values
+
+    def __str__(self):
+        text = '\n-----agenda-----\n'
+        text += '<stack top>\n'
+        for item in reversed(self.__stack):
+            text += str(item) + '\n'
+        text += '<stack btm>\n'
+        text += '-----agenda-----\n'
+        return text
+ + +
+def test():
+    user_simulator = UserPolicyAgendaMultiWoz()
+    user_simulator.init_session()
+
+    test_turn(user_simulator, {'Train-NoOffer': [['Day', 'saturday'], ['Dest', 'place'], ['Depart', 'place']]})
+    test_turn(user_simulator, {'Hotel-NoOffer': [['stars', '3'], ['internet', 'yes'], ['type', 'guest house']], 'Hotel-Request': [['stars', '?']]})
+    test_turn(user_simulator, {"Hotel-Inform": [["Type", "qqq"], ["Parking", "no"], ["Internet", "yes"]]})
+    test_turn(user_simulator, {"Hotel-Inform": [["Type", "qqq"], ["Parking", "no"]]})
+    test_turn(user_simulator, {"Booking-Request": [["Day", "123"], ["Time", "no"]]})
+    test_turn(user_simulator, {"Hotel-Request": [["Addr", "?"]], "Hotel-Inform": [["Internet", "yes"]]})
+    test_turn(user_simulator, {"Hotel-Request": [["Type", "?"], ["Parking", "?"]]})
+    test_turn(user_simulator, {"Hotel-Nooffer": [["Stars", "3"]], "Hotel-Request": [["Parking", "?"]]})
+    test_turn(user_simulator, {"Hotel-Select": [["Area", "aa"], ["Area", "bb"], ["Area", "cc"], ['Choice', 3]]})
+    test_turn(user_simulator, {"Hotel-Offerbooked": [["Ref", "12345"]]})
+ +
+def test_turn(user_simulator, sys_action):
+    print('input:', sys_action)
+    action, session_over, reward = user_simulator.predict(None, sys_action)
+    print('----------------------------------')
+    print('sys_action :' + str(sys_action))
+    print('user_action:' + str(action))
+    print('over       :' + str(session_over))
+    print('reward     :' + str(reward))
+    print(user_simulator.goal)
+    print(user_simulator.agenda)
+ + +
+def test_with_system():
+    from convlab.policy.system.rule_based_multiwoz_bot import RuleBasedMultiwozBot, fake_state
+    user_simulator = UserPolicyAgendaMultiWoz()
+    user_simulator.init_session()
+    state = fake_state()
+    system_agent = RuleBasedMultiwozBot(None, None, None)
+    sys_action = system_agent.predict(state)
+    action, session_over, reward = user_simulator.predict(None, sys_action)
+    print("Sys:")
+    print(json.dumps(sys_action, indent=4))
+    print("User:")
+    print(json.dumps(action, indent=4))
+
+
+if __name__ == '__main__':
+    test()
+    # test_with_system()
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/policy/user/multiwoz/policy_vhus.html b/docs/build/html/_modules/convlab/modules/policy/user/multiwoz/policy_vhus.html
new file mode 100644
index 0000000..ab5797f
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/policy/user/multiwoz/policy_vhus.html
@@ -0,0 +1,215 @@
+convlab.modules.policy.user.multiwoz.policy_vhus — ConvLab 0.1 documentation

Source code for convlab.modules.policy.user.multiwoz.policy_vhus

+# -*- coding: utf-8 -*-
+import os
+
+from convlab.modules.policy.user.policy import UserPolicy
+from convlab.modules.usr.multiwoz.vhus_usr.user import UserNeural
+
+
+
+class Goal():
+    def __init__(self, goal):
+        self.goal = goal
+
+    def task_complete(self):
+        return 0
+ +
+class UserPolicyVHUS(UserPolicy):
+
+    def __init__(self):
+        self.user = UserNeural()
+        path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))),
+                            'usr/multiwoz/vhus_user/model/best')
+        self.user.load(path)
+
+    def init_session(self):
+        self.user.init_session()
+        self.domain_goals = self.user.goal
+        self.goal = Goal(self.domain_goals)
+ +
+    def predict(self, state, sys_action):
+        usr_action, terminal = self.user.predict(state, sys_action)
+        return usr_action, terminal, 0
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/policy/user/policy.html b/docs/build/html/_modules/convlab/modules/policy/user/policy.html
new file mode 100644
index 0000000..3d4b0f8
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/policy/user/policy.html
@@ -0,0 +1,214 @@
+convlab.modules.policy.user.policy — ConvLab 0.1 documentation

Source code for convlab.modules.policy.user.policy

+"""
+The policy base class for user bot.
+"""
+
+
+
+class UserPolicy:
+    """Base class for user policy model."""
+
+    def __init__(self):
+        """ Constructor for UserPolicy class. """
+        pass
+
+    def predict(self, state, sys_action):
+        """
+        Predict a user act based on state and the previous system action.
+        Args:
+            state (tuple): Dialog state.
+            sys_action (tuple): Previous system action.
+        Returns:
+            action (tuple): User act.
+            session_over (boolean): True to terminate session, otherwise session continues.
+            reward (float): Reward given by the user.
+        """
+        pass
+ +
+    def init_session(self):
+        """
+        Restore after one session
+        """
+        pass
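Together with SysPolicy above, this interface supports a simple user-system simulation loop. A hedged sketch, omitting the DST, NLU and NLG components that a full ConvLab pipeline would place between the two policies (run_session is an invented helper name):

def run_session(user_policy, sys_policy, max_turns=20):
    user_policy.init_session()
    sys_policy.init_session()
    sys_action = "null"  # convention: no system act before the first user turn
    reward = 0.0
    for _ in range(max_turns):
        user_action, session_over, reward = user_policy.predict(None, sys_action)
        if session_over:
            break
        # A real pipeline would run a DST here to turn user_action into a dialog state.
        sys_action = sys_policy.predict(user_action)
    return reward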
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/state_encoder/multiwoz/multiwoz_state_encoder.html b/docs/build/html/_modules/convlab/modules/state_encoder/multiwoz/multiwoz_state_encoder.html
new file mode 100644
index 0000000..a60acf7
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/state_encoder/multiwoz/multiwoz_state_encoder.html
@@ -0,0 +1,373 @@
+convlab.modules.state_encoder.multiwoz.multiwoz_state_encoder — ConvLab 0.1 documentation

Source code for convlab.modules.state_encoder.multiwoz.multiwoz_state_encoder

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import numpy as np
+
+from convlab.modules.util.multiwoz.dbquery import query
+from convlab.modules.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
+
+
+
+class MultiWozStateEncoder(object):
+    def __init__(self):
+        pass
+
+    def encode(self, state):
+        db_vector = self.get_db_state(state['belief_state'])
+        book_vector = self.get_book_state(state['belief_state'])
+        info_vector = self.get_info_state(state['belief_state'])
+        request_vector = self.get_request_state(state['request_state'])
+        user_act_vector = self.get_user_act_state(state['user_action'])
+        history_vector = self.get_history_state(state['history'])
+
+        return np.concatenate((db_vector, book_vector, info_vector, request_vector, user_act_vector, history_vector))
+ +
+    def get_history_state(self, history):
+        history_vector = []
+
+        user_act = None
+        repeat_count = 0
+        user_act_repeat_vector = [0] * 5
+        for turn in reversed(history):
+            if user_act is None:
+                user_act = turn[1]
+            elif user_act == turn[1]:
+                repeat_count += 1
+            else:
+                break
+        user_act_repeat_vector[min(4, repeat_count)] = 1
+        history_vector += user_act_repeat_vector
+
+        return history_vector
+ +
+    def get_user_act_state(self, user_act):
+        user_act_vector = []
+
+        for domain in REF_SYS_DA:
+            for slot in REF_SYS_DA[domain]:
+                for act_type in ['Inform', 'Request', 'Booking']:
+                    domain_act = domain + '-' + act_type
+                    if domain_act in user_act and slot in [sv[0] for sv in user_act[domain_act]]:
+                        user_act_vector.append(1)
+                    else:
+                        user_act_vector.append(0)
+
+        return np.array(user_act_vector)
+ +
+    def get_request_state(self, request_state):
+        domains = ['taxi', 'restaurant', 'hospital', 'hotel', 'attraction', 'train', 'police']
+        request_vector = []
+
+        for domain in domains:
+            domain_vector = [0] * (len(REF_USR_DA[domain.capitalize()]) + 1)
+            if domain in request_state:
+                for slot in request_state[domain]:
+                    if slot == 'ref':
+                        domain_vector[-1] = 1
+                    else:
+                        domain_vector[list(REF_USR_DA[domain.capitalize()].keys()).index(slot)] = 1
+            request_vector.extend(domain_vector)
+
+        return np.array(request_vector)
+ +
+    def get_info_state(self, belief_state):
+        """Based on the mturk annotations we form multi-domain belief state"""
+        domains = ['taxi', 'restaurant', 'hospital', 'hotel', 'attraction', 'train', 'police']
+        info_vector = []
+
+        for domain in domains:
+            domain_active = False
+
+            booking = []
+            for slot in sorted(belief_state[domain]['book'].keys()):
+                if slot == 'booked':
+                    if belief_state[domain]['book']['booked'] != []:
+                        booking.append(1)
+                    else:
+                        booking.append(0)
+                else:
+                    if belief_state[domain]['book'][slot] != "":
+                        booking.append(1)
+                    else:
+                        booking.append(0)
+            info_vector += booking
+
+            for slot in belief_state[domain]['semi']:
+                slot_enc = [0, 0, 0]  # not mentioned, dontcare, filled
+                if belief_state[domain]['semi'][slot] in ['', 'not mentioned']:
+                    slot_enc[0] = 1
+                elif belief_state[domain]['semi'][slot] in ['dont care', 'dontcare', "don't care"]:
+                    slot_enc[1] = 1
+                    domain_active = True
+                elif belief_state[domain]['semi'][slot]:
+                    slot_enc[2] = 1
+                    domain_active = True
+                info_vector += slot_enc
+
+            # quasi domain-tracker
+            if domain_active:
+                info_vector += [1]
+            else:
+                info_vector += [0]
+
+        assert len(info_vector) == 93, f'info_vector {len(info_vector)}'
+        return np.array(info_vector)
+ +
+    def get_book_state(self, belief_state):
+        """Add information about availability of the booking option."""
+        # Booking pointer
+        rest_vec = np.array([1, 0])
+        if "book" in belief_state['restaurant']:
+            if "booked" in belief_state['restaurant']['book']:
+                if belief_state['restaurant']['book']["booked"]:
+                    if "reference" in belief_state['restaurant']['book']["booked"][0]:
+                        rest_vec = np.array([0, 1])
+
+        hotel_vec = np.array([1, 0])
+        if "book" in belief_state['hotel']:
+            if "booked" in belief_state['hotel']['book']:
+                if belief_state['hotel']['book']["booked"]:
+                    if "reference" in belief_state['hotel']['book']["booked"][0]:
+                        hotel_vec = np.array([0, 1])
+
+        train_vec = np.array([1, 0])
+        if "book" in belief_state['train']:
+            if "booked" in belief_state['train']['book']:
+                if belief_state['train']['book']["booked"]:
+                    if "reference" in belief_state['train']['book']["booked"][0]:
+                        train_vec = np.array([0, 1])
+
+        return np.concatenate((rest_vec, hotel_vec, train_vec))
+ +
+    def get_db_state(self, belief_state):
+        domains = ['restaurant', 'hotel', 'attraction', 'train']
+        db_vector = np.zeros(6 * len(domains))
+        for domain in domains:
+            entities = query(domain, belief_state[domain]['semi'].items())
+            db_vector = self.one_hot(len(entities), domain, domains, db_vector)
+
+        return db_vector
+ +
+    def one_hot(self, num, domain, domains, vector):
+        """Encode the number of available entities for a particular domain as a 6-way one-hot."""
+        idx = domains.index(domain)
+        if domain != 'train':
+            if num == 0:
+                vector[idx * 6: idx * 6 + 6] = np.array([1, 0, 0, 0, 0, 0])
+            elif num == 1:
+                vector[idx * 6: idx * 6 + 6] = np.array([0, 1, 0, 0, 0, 0])
+            elif num == 2:
+                vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 1, 0, 0, 0])
+            elif num == 3:
+                vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 1, 0, 0])
+            elif num == 4:
+                vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 1, 0])
+            elif num >= 5:
+                vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 0, 1])
+        else:
+            if num == 0:
+                vector[idx * 6: idx * 6 + 6] = np.array([1, 0, 0, 0, 0, 0])
+            elif num <= 2:
+                vector[idx * 6: idx * 6 + 6] = np.array([0, 1, 0, 0, 0, 0])
+            elif num <= 5:
+                vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 1, 0, 0, 0])
+            elif num <= 10:
+                vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 1, 0, 0])
+            elif num <= 40:
+                vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 1, 0])
+            elif num > 40:
+                vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 0, 1])
+
+        return vector
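An illustration, with made-up entity counts, of the bucketing implemented above: non-train domains use an exact-count one-hot for 0 to 4 plus a final 5+ bucket, while 'train' bins counts into the ranges 0, 1-2, 3-5, 6-10, 11-40 and >40:

enc = MultiWozStateEncoder()
domains = ['restaurant', 'hotel', 'attraction', 'train']
vec = np.zeros(6 * len(domains))
vec = enc.one_hot(3, 'restaurant', domains, vec)  # exact bucket -> [0, 0, 0, 1, 0, 0]
vec = enc.one_hot(37, 'train', domains, vec)      # range bucket 11..40 -> [0, 0, 0, 0, 1, 0]
print(vec.reshape(len(domains), 6))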
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/usr/multiwoz/goal_generator.html b/docs/build/html/_modules/convlab/modules/usr/multiwoz/goal_generator.html
new file mode 100644
index 0000000..6b01cae
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/usr/multiwoz/goal_generator.html
@@ -0,0 +1,832 @@
+convlab.modules.usr.multiwoz.goal_generator — ConvLab 0.1 documentation

Source code for convlab.modules.usr.multiwoz.goal_generator

+"""
+"""
+
+import json
+import os
+import pickle
+import random
+from collections import Counter
+from copy import deepcopy
+
+import numpy as np
+
+from convlab.modules.util.multiwoz import dbquery
+
+domains = {'attraction', 'hotel', 'restaurant', 'train', 'taxi', 'hospital', 'police'}
+days = ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']
+domain_keywords = {
+    'restaurant': 'place to dine',
+    'train': 'train',
+    'hotel': 'place to stay',
+    'attraction': 'places to go',
+    'police': 'help',
+    'taxi': 'taxi',
+    'hospital': 'hospital'
+}
+request_slot_string_map = {
+    'phone': 'phone number',
+    'pricerange': 'price range',
+    'duration': 'travel time',
+    'arriveBy': 'arrival time',
+    'leaveAt': 'departure time',
+    'trainID': 'train ID'
+}
+templates = {
+    'intro': 'You are looking for information in Cambridge.',
+    'restaurant': {
+        'intro': 'You are looking forward to trying local restaurants.',
+        'request': 'Once you find a restaurant, make sure you get {}.',
+        'area': 'The restaurant should be in the {}.',
+        'food': 'The restaurant should serve {} food.',
+        'name': 'You are looking for a particular restaurant. It is called {}.',
+        'pricerange': 'The restaurant should be in the {} price range.',
+        'book': 'Once you find the restaurant you want to book a table {}.',
+        'fail_info food': 'If there is no such restaurant, how about one that serves {} food.',
+        'fail_info area': 'If there is no such restaurant, how about one in the {} area.',
+        'fail_info pricerange': 'If there is no such restaurant, how about one in the {} price range.',
+        'fail_book time': 'If the booking fails how about {}.',
+        'fail_book day': 'If the booking fails how about {}.'
+    },
+    'hotel': {
+        'intro': 'You are looking for a place to stay.',
+        'request': 'Once you find a hotel, make sure you get {}.',
+        'stars': 'The hotel should have a star rating of {}.',
+        'area': 'The hotel should be in the {}.',
+        'type': 'The hotel should be of the {} type.',
+        'pricerange': 'The hotel should be in the {} price range.',
+        'name': 'You are looking for a particular hotel. It is called {}.',
+        'internet yes': 'The hotel should include free wifi.',
+        'internet no': 'The hotel does not need to include free wifi.',
+        'parking yes': 'The hotel should include free parking.',
+        'parking no': 'The hotel does not need to include free parking.',
+        'book': 'Once you find the hotel you want to book it {}.',
+        'fail_info type': 'If there is no such hotel, how about one of the {} type.',
+        'fail_info area': 'If there is no such hotel, how about one that is in the {} area.',
+        'fail_info stars': 'If there is no such hotel, how about one that has a star of {}.',
+        'fail_info pricerange': 'If there is no such hotel, how about one that is in the {} price range.',
+        'fail_info parking yes': 'If there is no such hotel, how about one that has free parking.',
+        'fail_info parking no': 'If there is no such hotel, how about one that does not have free parking.',
+        'fail_info internet yes': 'If there is no such hotel, how about one that has free wifi.',
+        'fail_info internet no': 'If there is no such hotel, how about one that does not have free wifi.',
+        'fail_book stay': 'If the booking fails how about {} nights.',
+        'fail_book day': 'If the booking fails how about {}.'
+    },
+    'attraction': {
+        'intro': 'You are excited about seeing local tourist attractions.',
+        'request': 'Once you find an attraction, make sure you get {}.',
+        'area': 'The attraction should be in the {}.',
+        'type': 'The attraction should be of the {} type.',
+        'name': 'You are looking for a particular attraction. It is called {}.',
+        'fail_info type': 'If there is no such attraction, how about one of the {} type.',
+        'fail_info area': 'If there is no such attraction, how about one in the {} area.'
+    },
+    'taxi': {
+        'intro': 'You are also looking for a taxi.',
+        'commute': 'You also want to book a taxi to commute between the two places.',
+        'restaurant': 'You want to make sure it arrives at the restaurant by the booked time.',
+        'request': 'Once you find a taxi, make sure you get {}.',
+        'departure': 'The taxi should depart from {}.',
+        'destination': 'The taxi should go to {}.',
+        'leaveAt': 'The taxi should leave after {}.',
+        'arriveBy': 'The taxi should arrive by {}.'
+    },
+    'train': {
+        'intro': 'You are also looking for a train.',
+        'request': 'Once you find a train, make sure you get {}.',
+        'departure': 'The train should depart from {}.',
+        'destination': 'The train should go to {}.',
+        'day': 'The train should leave on {}.',
+        'leaveAt': 'The train should leave after {}.',
+        'arriveBy': 'The train should arrive by {}.',
+        'book': 'Once you find the train you want to make a booking {}.'
+    },
+    'police': {
+        'intro': 'You were robbed and are looking for help.',
+        'request': 'Make sure you get {}.'
+    },
+    'hospital': {
+        'intro': 'You got injured and are looking for a hospital nearby.',
+        'request': 'Make sure you get {}.',
+        'department': 'The hospital should have the {} department.'
+    }
+}
+
+pro_correction = {
+    # "info": 0.2,
+    "info": 0.0,
+    # "reqt": 0.2,
+    "reqt": 0.0,
+    # "book": 0.2
+    "book": 0.0
+}
+
+
+
[docs]def null_boldify(content): + return content
+ +
[docs]def do_boldify(content): + return '<b>' + content + '</b>'
+ +
[docs]def nomial_sample(counter: Counter): + return list(counter.keys())[np.argmax(np.random.multinomial(1, list(counter.values())))]
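nomial_sample draws a single key from a Counter whose values are probabilities (they should sum to 1), using one multinomial draw. A small illustrative use:

from collections import Counter

day_dist = Counter({'monday': 0.5, 'friday': 0.3, 'sunday': 0.2})
print(nomial_sample(day_dist))  # prints 'monday' about half the time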
+ +
[docs]class GoalGenerator: + """User goal generator.""" + + def __init__(self, + goal_model_path=os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))), + 'data/multiwoz/goal/goal_model.pkl'), + corpus_path=None, + boldify=False): + """ + Args: + goal_model_path: path to a goal model + corpus_path: path to a dialog corpus to build a goal model + """ + self.goal_model_path = goal_model_path + self.corpus_path = corpus_path + self.boldify = do_boldify if boldify else null_boldify + if os.path.exists(self.goal_model_path): + self.ind_slot_dist, self.ind_slot_value_dist, self.domain_ordering_dist, self.book_dist = pickle.load( + open(self.goal_model_path, 'rb')) + print('Loading goal model is done') + else: + self._build_goal_model() + print('Building goal model is done') + + # remove some slot + del self.ind_slot_dist['police']['reqt']['postcode'] + del self.ind_slot_value_dist['police']['reqt']['postcode'] + del self.ind_slot_dist['hospital']['reqt']['postcode'] + del self.ind_slot_value_dist['hospital']['reqt']['postcode'] + del self.ind_slot_dist['hospital']['reqt']['address'] + del self.ind_slot_value_dist['hospital']['reqt']['address'] + + def _build_goal_model(self): + dialogs = json.load(open(self.corpus_path)) + + # domain ordering + def _get_dialog_domains(dialog): + return list(filter(lambda x: x in domains and len(dialog['goal'][x]) > 0, dialog['goal'])) + + domain_orderings = [] + for d in dialogs: + d_domains = _get_dialog_domains(dialogs[d]) + first_index = [] + for domain in d_domains: + message = [dialogs[d]['goal']['message']] if isinstance(dialogs[d]['goal']['message'], str) else \ + dialogs[d]['goal']['message'] + for i, m in enumerate(message): + if domain_keywords[domain].lower() in m.lower() or domain.lower() in m.lower(): + first_index.append(i) + break + domain_orderings.append(tuple(map(lambda x: x[1], sorted(zip(first_index, d_domains), key=lambda x: x[0])))) + domain_ordering_cnt = Counter(domain_orderings) + self.domain_ordering_dist = deepcopy(domain_ordering_cnt) + for order in domain_ordering_cnt.keys(): + self.domain_ordering_dist[order] = domain_ordering_cnt[order] / sum(domain_ordering_cnt.values()) + + # independent goal slot distribution + ind_slot_value_cnt = dict([(domain, {}) for domain in domains]) + domain_cnt = Counter() + book_cnt = Counter() + + for d in dialogs: + for domain in domains: + if dialogs[d]['goal'][domain] != {}: + domain_cnt[domain] += 1 + if 'info' in dialogs[d]['goal'][domain]: + for slot in dialogs[d]['goal'][domain]['info']: + if 'invalid' in slot: + continue + if 'info' not in ind_slot_value_cnt[domain]: + ind_slot_value_cnt[domain]['info'] = {} + if slot not in ind_slot_value_cnt[domain]['info']: + ind_slot_value_cnt[domain]['info'][slot] = Counter() + if 'care' in dialogs[d]['goal'][domain]['info'][slot]: + continue + ind_slot_value_cnt[domain]['info'][slot][dialogs[d]['goal'][domain]['info'][slot]] += 1 + if 'reqt' in dialogs[d]['goal'][domain]: + for slot in dialogs[d]['goal'][domain]['reqt']: + if 'reqt' not in ind_slot_value_cnt[domain]: + ind_slot_value_cnt[domain]['reqt'] = Counter() + ind_slot_value_cnt[domain]['reqt'][slot] += 1 + if 'book' in dialogs[d]['goal'][domain]: + book_cnt[domain] += 1 + for slot in dialogs[d]['goal'][domain]['book']: + if 'invalid' in slot: + continue + if 'book' not in ind_slot_value_cnt[domain]: + ind_slot_value_cnt[domain]['book'] = {} + if slot not in ind_slot_value_cnt[domain]['book']: + 
ind_slot_value_cnt[domain]['book'][slot] = Counter() + if 'care' in dialogs[d]['goal'][domain]['book'][slot]: + continue + ind_slot_value_cnt[domain]['book'][slot][dialogs[d]['goal'][domain]['book'][slot]] += 1 + + self.ind_slot_value_dist = deepcopy(ind_slot_value_cnt) + self.ind_slot_dist = dict([(domain, {}) for domain in domains]) + self.book_dist = {} + for domain in domains: + if 'info' in ind_slot_value_cnt[domain]: + for slot in ind_slot_value_cnt[domain]['info']: + if 'info' not in self.ind_slot_dist[domain]: + self.ind_slot_dist[domain]['info'] = {} + if slot not in self.ind_slot_dist[domain]['info']: + self.ind_slot_dist[domain]['info'][slot] = {} + self.ind_slot_dist[domain]['info'][slot] = sum(ind_slot_value_cnt[domain]['info'][slot].values()) / \ + domain_cnt[domain] + slot_total = sum(ind_slot_value_cnt[domain]['info'][slot].values()) + for val in self.ind_slot_value_dist[domain]['info'][slot]: + self.ind_slot_value_dist[domain]['info'][slot][val] = ind_slot_value_cnt[domain]['info'][slot][ + val] / slot_total + if 'reqt' in ind_slot_value_cnt[domain]: + for slot in ind_slot_value_cnt[domain]['reqt']: + if 'reqt' not in self.ind_slot_dist[domain]: + self.ind_slot_dist[domain]['reqt'] = {} + self.ind_slot_dist[domain]['reqt'][slot] = ind_slot_value_cnt[domain]['reqt'][slot] / domain_cnt[ + domain] + self.ind_slot_value_dist[domain]['reqt'][slot] = ind_slot_value_cnt[domain]['reqt'][slot] / \ + domain_cnt[domain] + if 'book' in ind_slot_value_cnt[domain]: + for slot in ind_slot_value_cnt[domain]['book']: + if 'book' not in self.ind_slot_dist[domain]: + self.ind_slot_dist[domain]['book'] = {} + if slot not in self.ind_slot_dist[domain]['book']: + self.ind_slot_dist[domain]['book'][slot] = {} + self.ind_slot_dist[domain]['book'][slot] = sum(ind_slot_value_cnt[domain]['book'][slot].values()) / \ + domain_cnt[domain] + slot_total = sum(ind_slot_value_cnt[domain]['book'][slot].values()) + for val in self.ind_slot_value_dist[domain]['book'][slot]: + self.ind_slot_value_dist[domain]['book'][slot][val] = ind_slot_value_cnt[domain]['book'][slot][ + val] / slot_total + self.book_dist[domain] = book_cnt[domain] / len(dialogs) + + pickle.dump((self.ind_slot_dist, self.ind_slot_value_dist, self.domain_ordering_dist, self.book_dist), + open(self.goal_model_path, 'wb')) + + def _get_domain_goal(self, domain): + cnt_slot = self.ind_slot_dist[domain] + cnt_slot_value = self.ind_slot_value_dist[domain] + pro_book = self.book_dist[domain] + + while True: + # domain_goal = defaultdict(lambda: {}) + # domain_goal = {'info': {}, 'fail_info': {}, 'reqt': {}, 'book': {}, 'fail_book': {}} + domain_goal = {'info': {}} + # inform + if 'info' in cnt_slot: + for slot in cnt_slot['info']: + if random.random() < cnt_slot['info'][slot] + pro_correction['info']: + domain_goal['info'][slot] = nomial_sample(cnt_slot_value['info'][slot]) + + if domain in ['hotel', 'restaurant', 'attraction'] and 'name' in domain_goal['info'] and len( + domain_goal['info']) > 1: + if random.random() < cnt_slot['info']['name']: + domain_goal['info'] = {'name': domain_goal['info']['name']} + else: + del domain_goal['info']['name'] + + if domain in ['taxi', 'train'] and 'arriveBy' in domain_goal['info'] and 'leaveAt' in domain_goal[ + 'info']: + if random.random() < ( + cnt_slot['info']['leaveAt'] / (cnt_slot['info']['arriveBy'] + cnt_slot['info']['leaveAt'])): + del domain_goal['info']['arriveBy'] + else: + del domain_goal['info']['leaveAt'] + + if domain in ['taxi', 'train'] and 'arriveBy' not in domain_goal['info'] and 'leaveAt' 
not in \ + domain_goal['info']: + if random.random() < (cnt_slot['info']['arriveBy'] / ( + cnt_slot['info']['arriveBy'] + cnt_slot['info']['leaveAt'])): + domain_goal['info']['arriveBy'] = nomial_sample(cnt_slot_value['info']['arriveBy']) + else: + domain_goal['info']['leaveAt'] = nomial_sample(cnt_slot_value['info']['leaveAt']) + + if domain in ['taxi', 'train'] and 'departure' not in domain_goal['info']: + domain_goal['info']['departure'] = nomial_sample(cnt_slot_value['info']['departure']) + + if domain in ['taxi', 'train'] and 'destination' not in domain_goal['info']: + domain_goal['info']['destination'] = nomial_sample(cnt_slot_value['info']['destination']) + + if domain in ['taxi', 'train'] and \ + 'departure' in domain_goal['info'] and \ + 'destination' in domain_goal['info'] and \ + domain_goal['info']['departure'] == domain_goal['info']['destination']: + if random.random() < (cnt_slot['info']['departure'] / ( + cnt_slot['info']['departure'] + cnt_slot['info']['destination'])): + domain_goal['info']['departure'] = nomial_sample(cnt_slot_value['info']['departure']) + else: + domain_goal['info']['destination'] = nomial_sample(cnt_slot_value['info']['destination']) + if domain_goal['info'] == {}: + continue + # request + if 'reqt' in cnt_slot: + reqt = [slot for slot in cnt_slot['reqt'] + if random.random() < cnt_slot['reqt'][slot] + pro_correction['reqt'] and slot not in + domain_goal['info']] + if len(reqt) > 0: + domain_goal['reqt'] = reqt + + # book + if 'book' in cnt_slot and random.random() < pro_book + pro_correction['book']: + if 'book' not in domain_goal: + domain_goal['book'] = {} + + for slot in cnt_slot['book']: + if random.random() < cnt_slot['book'][slot] + pro_correction['book']: + domain_goal['book'][slot] = nomial_sample(cnt_slot_value['book'][slot]) + + # makes sure that there are all necessary slots for booking + if domain == 'restaurant' and 'time' not in domain_goal['book']: + domain_goal['book']['time'] = nomial_sample(cnt_slot_value['book']['time']) + + if domain == 'hotel' and 'stay' not in domain_goal['book']: + domain_goal['book']['stay'] = nomial_sample(cnt_slot_value['book']['stay']) + + if domain in ['hotel', 'restaurant'] and 'day' not in domain_goal['book']: + domain_goal['book']['day'] = nomial_sample(cnt_slot_value['book']['day']) + + if domain in ['hotel', 'restaurant'] and 'people' not in domain_goal['book']: + domain_goal['book']['people'] = nomial_sample(cnt_slot_value['book']['people']) + + if domain == 'train' and len(domain_goal['book']) <= 0: + domain_goal['book']['people'] = nomial_sample(cnt_slot_value['book']['people']) + + # fail_book + if 'book' in domain_goal and random.random() < 0.5: + if domain == 'hotel': + domain_goal['fail_book'] = deepcopy(domain_goal['book']) + if 'stay' in domain_goal['book'] and random.random() < 0.5: + # increase hotel-stay + domain_goal['fail_book']['stay'] = str(int(domain_goal['book']['stay']) + 1) + elif 'day' in domain_goal['book']: + # push back hotel-day by a day + domain_goal['fail_book']['day'] = days[(days.index(domain_goal['book']['day']) - 1) % 7] + + elif domain == 'restaurant': + domain_goal['fail_book'] = deepcopy(domain_goal['book']) + if 'time' in domain_goal['book'] and random.random() < 0.5: + hour, minute = domain_goal['book']['time'].split(':') + domain_goal['fail_book']['time'] = str((int(hour) + 1) % 24) + ':' + minute + elif 'day' in domain_goal['book']: + if random.random() < 0.5: + domain_goal['fail_book']['day'] = days[(days.index(domain_goal['book']['day']) - 1) % 7] + else: + 
domain_goal['fail_book']['day'] = days[(days.index(domain_goal['book']['day']) + 1) % 7] + + # fail_info + if 'info' in domain_goal and len(dbquery.query(domain, domain_goal['info'].items())) == 0: + num_trial = 0 + while num_trial < 100: + adjusted_info = self._adjust_info(domain, domain_goal['info']) + if len(dbquery.query(domain, adjusted_info.items())) > 0: + if domain == 'train': + domain_goal['info'] = adjusted_info + else: + domain_goal['fail_info'] = domain_goal['info'] + domain_goal['info'] = adjusted_info + + break + num_trial += 1 + + if num_trial >= 100: + continue + + # at least there is one request and book + if 'reqt' in domain_goal or 'book' in domain_goal: + break + + return domain_goal + +
[docs] def get_user_goal(self, seed=None): + if seed is not None: + random.seed(seed) + np.random.seed(seed) + domain_ordering = () + while len(domain_ordering) <= 0: + domain_ordering = nomial_sample(self.domain_ordering_dist) + # domain_ordering = ('restaurant',) + + user_goal = {dom: self._get_domain_goal(dom) for dom in domain_ordering} + assert len(user_goal.keys()) > 0 + + # using taxi to communte between places, removing destination and departure. + if 'taxi' in domain_ordering: + places = [dom for dom in domain_ordering[: domain_ordering.index('taxi')] if 'address' in self.ind_slot_dist[dom]['reqt'].keys()] + if len(places) >= 1: + del user_goal['taxi']['info']['destination'] + user_goal[places[-1]]['reqt'] = list(set(user_goal[places[-1]].get('reqt', [])).union({'address'})) + if places[-1] == 'restaurant' and 'book' in user_goal['restaurant']: + user_goal['taxi']['info']['arriveBy'] = user_goal['restaurant']['book']['time'] + if 'leaveAt' in user_goal['taxi']['info']: + del user_goal['taxi']['info']['leaveAt'] + if len(places) >= 2: + del user_goal['taxi']['info']['departure'] + user_goal[places[-2]]['reqt'] = list(set(user_goal[places[-2]].get('reqt', [])).union({'address'})) + + # match area of attraction and restaurant + if 'restaurant' in domain_ordering and \ + 'attraction' in domain_ordering and \ + 'fail_info' not in user_goal['restaurant'] and \ + domain_ordering.index('restaurant') > domain_ordering.index('attraction') and \ + 'area' in user_goal['restaurant']['info'] and 'area' in user_goal['attraction']['info']: + adjusted_restaurant_goal = deepcopy(user_goal['restaurant']['info']) + adjusted_restaurant_goal['area'] = user_goal['attraction']['info']['area'] + if len(dbquery.query('restaurant', adjusted_restaurant_goal.items())) > 0 and random.random() < 0.5: + user_goal['restaurant']['info']['area'] = user_goal['attraction']['info']['area'] + + # match day and people of restaurant and hotel + if 'restaurant' in domain_ordering and 'hotel' in domain_ordering and \ + 'book' in user_goal['restaurant'] and 'book' in user_goal['hotel']: + if random.random() < 0.5: + user_goal['restaurant']['book']['people'] = user_goal['hotel']['book']['people'] + if 'fail_book' in user_goal['restaurant']: + user_goal['restaurant']['fail_book']['people'] = user_goal['hotel']['book']['people'] + if random.random() < 1.0: + user_goal['restaurant']['book']['day'] = user_goal['hotel']['book']['day'] + if 'fail_book' in user_goal['restaurant']: + user_goal['restaurant']['fail_book']['day'] = user_goal['hotel']['book']['day'] + if user_goal['restaurant']['book']['day'] == user_goal['restaurant']['fail_book']['day'] and \ + user_goal['restaurant']['book']['time'] == user_goal['restaurant']['fail_book']['time'] and \ + user_goal['restaurant']['book']['people'] == user_goal['restaurant']['fail_book']['people']: + del user_goal['restaurant']['fail_book'] + + # match day and people of hotel and train + if 'hotel' in domain_ordering and 'train' in domain_ordering and \ + 'book' in user_goal['hotel'] and 'info' in user_goal['train']: + if user_goal['train']['info']['destination'] == 'cambridge' and \ + 'day' in user_goal['hotel']['book']: + user_goal['train']['info']['day'] = user_goal['hotel']['book']['day'] + elif user_goal['train']['info']['departure'] == 'cambridge' and \ + 'day' in user_goal['hotel']['book'] and 'stay' in user_goal['hotel']['book']: + user_goal['train']['info']['day'] = days[ + (days.index(user_goal['hotel']['book']['day']) + int( + user_goal['hotel']['book']['stay'])) % 7] + # In 
case we have no query results with the adjusted train goal, simply drop the train goal.
            if len(dbquery.query('train', user_goal['train']['info'].items())) == 0:
                del user_goal['train']
                # list.remove() returns None, so mutate a copy and rebuild the tuple
                ordering = list(domain_ordering)
                ordering.remove('train')
                domain_ordering = tuple(ordering)

        user_goal['domain_ordering'] = domain_ordering

        return user_goal
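get_user_goal therefore returns a dict keyed by domain, plus the 'domain_ordering' tuple. A hypothetical (abridged) sketch of the result shape, assuming a prebuilt goal model is on disk:

goal_generator = GoalGenerator()  # loads data/multiwoz/goal/goal_model.pkl if present
user_goal = goal_generator.get_user_goal(seed=0)
# Hypothetical shape of the result:
# {'restaurant': {'info': {'food': 'italian', 'area': 'centre'},
#                 'reqt': ['phone'],
#                 'book': {'people': '2', 'day': 'monday', 'time': '18:00'}},
#  'domain_ordering': ('restaurant',)}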
+ + def _adjust_info(self, domain, info): + # adjust one of the slots of the info + adjusted_info = deepcopy(info) + slot = random.choice(list(info.keys())) + adjusted_info[slot] = random.choice(list(self.ind_slot_value_dist[domain]['info'][slot].keys())) + return adjusted_info + +
[docs] def build_message(self, user_goal, boldify=null_boldify): + message = [] + state = deepcopy(user_goal) + + for dom in user_goal['domain_ordering']: + dom_msg = [] + state = deepcopy(user_goal[dom]) + num_acts_in_unit = 0 + + if not (dom == 'taxi' and len(state['info']) == 1): + # intro + m = [templates[dom]['intro']] + + # info + def fill_info_template(user_goal, domain, slot, info): + if slot != 'area' or not ('restaurant' in user_goal and + 'attraction' in user_goal and + info in user_goal['restaurant'].keys() and + info in user_goal['attraction'].keys() and + 'area' in user_goal['restaurant'][info] and + 'area' in user_goal['attraction'][info] and + user_goal['restaurant'][info]['area'] == user_goal['attraction'][info]['area']): + return templates[domain][slot].format(self.boldify(user_goal[domain][info][slot])) + else: + restaurant_index = user_goal['domain_ordering'].index('restaurant') + attraction_index = user_goal['domain_ordering'].index('attraction') + if restaurant_index > attraction_index and domain == 'restaurant': + return templates[domain][slot].format(self.boldify('same area as the attraction')) + elif attraction_index > restaurant_index and domain == 'attraction': + return templates[domain][slot].format(self.boldify('same area as the restaurant')) + return templates[domain][slot].format(self.boldify(user_goal[domain][info][slot])) + + info = 'info' + if 'fail_info' in user_goal[dom]: + info = 'fail_info' + if dom == 'taxi' and len(state[info]) == 1: + taxi_index = user_goal['domain_ordering'].index('taxi') + places = [dom for dom in user_goal['domain_ordering'][: taxi_index] if + dom in ['attraction', 'hotel', 'restaurant']] + if len(places) >= 2: + random.shuffle(places) + m.append(templates['taxi']['commute']) + if 'arriveBy' in state[info]: + m.append('The taxi should arrive at the {} from the {} by {}.'.format(self.boldify(places[0]), + self.boldify(places[1]), + self.boldify(state[info]['arriveBy']))) + elif 'leaveAt' in state[info]: + m.append('The taxi should leave from the {} to the {} after {}.'.format(self.boldify(places[0]), + self.boldify(places[1]), + self.boldify(state[info]['leaveAt']))) + message.append(' '.join(m)) + else: + while len(state[info]) > 0: + num_acts = random.randint(1, min(len(state[info]), 3)) + slots = random.sample(list(state[info].keys()), num_acts) + sents = [fill_info_template(user_goal, dom, slot, info) for slot in slots if slot not in ['parking', 'internet']] + if 'parking' in slots: + sents.append(templates[dom]['parking ' + state[info]['parking']]) + if 'internet' in slots: + sents.append(templates[dom]['internet ' + state[info]['internet']]) + m.extend(sents) + message.append(' '.join(m)) + m = [] + for slot in slots: + del state[info][slot] + + # fail_info + if 'fail_info' in user_goal[dom]: + # if 'fail_info' in user_goal[dom]: + adjusted_slot = list(filter(lambda x: x[0][1] != x[1][1], + zip(user_goal[dom]['info'].items(), user_goal[dom]['fail_info'].items())))[0][0][0] + if adjusted_slot in ['internet', 'parking']: + message.append(templates[dom]['fail_info ' + adjusted_slot + ' ' + user_goal[dom]['info'][adjusted_slot]]) + else: + message.append(templates[dom]['fail_info ' + adjusted_slot].format(self.boldify(user_goal[dom]['info'][adjusted_slot]))) + + # reqt + if 'reqt' in state: + slot_strings = [] + for slot in state['reqt']: + if slot in ['internet', 'parking', 'food']: + continue + slot_strings.append(slot if slot not in request_slot_string_map else request_slot_string_map[slot]) + if len(slot_strings) > 0: + 
message.append(templates[dom]['request'].format(self.boldify(', '.join(slot_strings)))) + if 'internet' in state['reqt']: + message.append('Make sure to ask if the hotel includes free wifi.') + if 'parking' in state['reqt']: + message.append('Make sure to ask if the hotel includes free parking.') + if 'food' in state['reqt']: + message.append('Make sure to ask about what food it serves.') + + def get_same_people_domain(user_goal, domain, slot): + if slot not in ['day', 'people']: + return None + domain_index = user_goal['domain_ordering'].index(domain) + previous_domains = user_goal['domain_ordering'][:domain_index] + for prev in previous_domains: + if prev in ['restaurant', 'hotel', 'train'] and 'book' in user_goal[prev] and \ + slot in user_goal[prev]['book'] and user_goal[prev]['book'][slot] == \ + user_goal[domain]['book'][slot]: + return prev + return None + + # book + book = 'book' + if 'fail_book' in user_goal[dom]: + book = 'fail_book' + if 'book' in state: + slot_strings = [] + for slot in ['people', 'time', 'day', 'stay']: + if slot in state[book]: + if slot == 'people': + same_people_domain = get_same_people_domain(user_goal, dom, slot) + if same_people_domain is None: + slot_strings.append('for {} people'.format(self.boldify(state[book][slot]))) + else: + slot_strings.append(self.boldify( + 'for the same group of people as the {} booking'.format(same_people_domain))) + elif slot == 'time': + slot_strings.append('at {}'.format(self.boldify(state[book][slot]))) + elif slot == 'day': + same_people_domain = get_same_people_domain(user_goal, dom, slot) + if same_people_domain is None: + slot_strings.append('on {}'.format(self.boldify(state[book][slot]))) + else: + slot_strings.append( + self.boldify('on the same day as the {} booking'.format(same_people_domain))) + elif slot == 'stay': + slot_strings.append('for {} nights'.format(self.boldify(state[book][slot]))) + del state[book][slot] + + assert len(state[book]) <= 0, state[book] + + if len(slot_strings) > 0: + message.append(templates[dom]['book'].format(' '.join(slot_strings))) + + # fail_book + if 'fail_book' in user_goal[dom]: + adjusted_slot = list(filter(lambda x: x[0][1] != x[1][1], zip(user_goal[dom]['book'].items(), + user_goal[dom]['fail_book'].items())))[0][0][0] + + if adjusted_slot in ['internet', 'parking']: + message.append( + templates[dom]['fail_book ' + adjusted_slot + ' ' + user_goal[dom]['book'][adjusted_slot]]) + else: + message.append(templates[dom]['fail_book ' + adjusted_slot].format( + self.boldify(user_goal[dom]['book'][adjusted_slot]))) + + if boldify == do_boldify: + for i, m in enumerate(message): + message[i] = message[i].replace('wifi', "<b>wifi</b>") + message[i] = message[i].replace('internet', "<b>internet</b>") + message[i] = message[i].replace('parking', "<b>parking</b>") + + return message
+ + +if __name__ == "__main__": + goal_generator = GoalGenerator(corpus_path='data/multiwoz/annotated_user_da_with_span_full.json') + while True: + user_goal = goal_generator.get_user_goal() + print(user_goal) + # message = goal_generator.build_message(user_goal) + # pprint(message) +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/usr/user.html b/docs/build/html/_modules/convlab/modules/usr/user.html
new file mode 100644
index 0000000..d763e21
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/usr/user.html
@@ -0,0 +1,250 @@
+ convlab.modules.usr.user — ConvLab 0.1 documentation

Source code for convlab.modules.usr.user

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+"""
+
+
[docs]class UserSimulator:
    """An aggregation of user simulator components."""
    def __init__(self, nlu_model, policy, nlg_model):
        """
        The constructor of the UserSimulator class. The inputs are the models of each component.
        Args:
            nlu_model (NLU): An instance of the NLU class.
            policy (UserPolicy): An instance of the Policy class.
            nlg_model (NLG): An instance of the NLG class.
        """
        self.nlu_model = nlu_model
        self.policy = policy
        self.nlg_model = nlg_model

        self.sys_act = None
        self.current_action = None
        self.policy.init_session()
[docs] def response(self, input, context=None):
        """
        Generate the user response.
        Args:
            input (str or dict): The preceding system output; a str if the system has an NLG component, else a dict.
        Returns:
            output (str or dict): The user response. If the NLG component is None, type(output) == dict, else str.
            action (dict): The dialog act of output. Note that if the NLG component is None, output and action are
                identical.
            session_over (boolean): True to terminate the session, else the session continues.
            reward (float): The reward given by the user.
        """
        if context is None:
            context = []  # avoid a shared mutable default argument

        if self.nlu_model is not None:
            sys_act = self.nlu_model.parse(input, context)
        else:
            sys_act = input
        self.sys_act = sys_act
        action, session_over, reward = self.policy.predict(None, sys_act)
        if self.nlg_model is not None:
            output = self.nlg_model.generate(action)
        else:
            output = action

        self.current_action = action

        return output, action, session_over, reward
+ +
[docs] def init_session(self):
        """Initialize the parameters for a new session by calling the init_session method of the policy component."""
        self.policy.init_session()
        self.current_action = None
+ +
[docs] def init_response(self):
        """Return an initial response of the user."""
        if self.nlg_model is not None:
            output = self.nlg_model.generate({})
        else:
            output = {}
        return output
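A wiring sketch of the three components; `my_nlu`, `my_user_policy`, and `my_nlg` are placeholders for concrete NLU/UserPolicy/NLG instances and are assumptions here, not part of the module:

user_sim = UserSimulator(nlu_model=my_nlu, policy=my_user_policy, nlg_model=my_nlg)
user_sim.init_session()
first_utterance = user_sim.init_response()
# feed the system's reply back in each turn:
output, action, session_over, reward = user_sim.response('Hello, how can I help you?')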
+
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/word_dst/multiwoz/evaluate.html b/docs/build/html/_modules/convlab/modules/word_dst/multiwoz/evaluate.html
new file mode 100644
index 0000000..97a58a8
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/word_dst/multiwoz/evaluate.html
@@ -0,0 +1,382 @@
+ convlab.modules.word_dst.multiwoz.evaluate — ConvLab 0.1 documentation

Source code for convlab.modules.word_dst.multiwoz.evaluate

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import json
+
+from convlab.modules.dst.multiwoz.dst_util import minDistance
+from convlab.modules.word_dst.multiwoz.mdbt import MDBTTracker
+
+
+
[docs]class Word_DST:
    """A temporary, semi-finished agent for word-level DST testing: it takes utterances as input and outputs the dialog state."""
    def __init__(self):
        self.dst = MDBTTracker(data_dir='../../../../data/mdbt')
        self.nlu = None
[docs] def update(self, action, observation): + # update history + self.dst.state['history'].append([str(action)]) + + # NLU parsing + input_act = self.nlu.parse(observation, sum(self.dst.state['history'], [])) if self.nlu else observation + + # state tracking + self.dst.update(input_act) + self.dst.state['history'][-1].append(observation) + + # update history + return self.dst.state
+ +
[docs] def reset(self): + self.dst.init_session()
+ +
[docs]def load_data(path='../../../../data/multiwoz/test.json'): + """Load data (mainly for testing data).""" + data = json.load(open(path)) + result = [] + for id, session in data.items(): + log = session['log'] + turn_data = ['null'] + goal = session['goal'] + session_data = [] + for turn_idx, turn in enumerate(log): + if turn_idx % 2 == 0: # user + observation = turn['text'] + turn_data.append(observation) + else: # system + action = turn['text'] + golden_state = turn['metadata'] + turn_data.append(golden_state) + session_data.append(turn_data) + turn_data = [action] + result.append([session_data, goal]) + return result
+ +
[docs]def run_test(): + agent = Word_DST() + agent.reset() + + test_data = load_data() + test_result = [] + for session_data, goal in test_data: + session_result = [] + for action, observation, golden_state in session_data: + pred_state = agent.update(action, observation) + session_result.append([golden_state, pred_state['belief_state']]) + test_result.append([session_result, goal]) + agent.reset() + return test_result
+ +
[docs]class ResultStat:
    """A helper class for accumulating accuracy statistics."""
    def __init__(self):
        self.stat = {}
[docs] def add(self, domain, slot, score): + """score = 1 or 0""" + if domain in self.stat: + if slot in self.stat[domain]: + self.stat[domain][slot][0] += score + self.stat[domain][slot][1] += 1. + else: + self.stat[domain][slot] = [score, 1.] + else: + self.stat[domain] = {slot: [score, 1.]}
+ +
[docs] def domain_acc(self, domain): + domain_stat = self.stat[domain] + ret = [0, 0] + for _, r in domain_stat.items(): + ret[0] += r[0] + ret[1] += r[1] + return ret[0]/(ret[1] + 1e-10)
+ +
[docs] def slot_acc(self, domain, slot): + slot_stat = self.stat[domain][slot] + return slot_stat[0] / (slot_stat[1] + 1e-10)
+ +
[docs] def all_acc(self): + acc_result = {} + for domain in self.stat: + acc_result[domain] = {} + acc_result[domain]['acc'] = self.domain_acc(domain) + for slot in self.stat[domain]: + acc_result[domain][slot+'_acc'] = self.slot_acc(domain, slot) + return json.dumps(acc_result, indent=4)
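A small usage sketch of the accumulator: each add() call contributes one correct/incorrect observation for a (domain, slot) pair, and the *_acc helpers report the ratios.

stat = ResultStat()
stat.add('restaurant', 'food', 1)   # correct prediction
stat.add('restaurant', 'area', 0)   # wrong prediction
stat.add('hotel', 'stars', 1)
print(stat.domain_acc('restaurant'))    # ~0.5 (1 correct out of 2)
print(stat.slot_acc('hotel', 'stars'))  # ~1.0
print(stat.all_acc())                   # JSON string with per-domain and per-slot accuracies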
+ +
[docs]def evaluate(test_result): + stat = ResultStat() + session_level = [0., 0.] + for session, goal in test_result: + last_pred_state = None + for golden_state, pred_state in session: # session + last_pred_state = pred_state + domains = golden_state.keys() + for domain in domains: # domain + if domain == 'bus': + continue + assert domain in pred_state, 'domain: {}'.format(domain) + golden_domain, pred_domain = golden_state[domain], pred_state[domain] + for slot, value in golden_domain['semi'].items(): # slot + if _is_empty(slot, golden_domain['semi']): + continue + pv = pred_domain['semi'][slot] if slot in pred_domain['semi'] else '_None' + score = 0. + if _is_match(value, pv): + score = 1. + stat.add(domain, slot, score) + if match_goal(last_pred_state, goal): + session_level[0] += 1 + session_level[1] += 1 + print('domain and slot-level acc:') + print(stat.all_acc()) + print('session-level acc: {}'.format(convert2acc(session_level[0], session_level[1])))
+ +
[docs]def convert2acc(a, b): + if b == 0: + return -1 + return a/b
+ +
[docs]def match_goal(pred_state, goal): + domains = pred_state.keys() + for domain in domains: + if domain not in goal: + continue + goal_domain = goal[domain] + if 'info' not in goal_domain: + continue + goal_domain_info = goal_domain['info'] + for slot, value in goal_domain_info.items(): + if slot in pred_state[domain]['semi']: + v = pred_state[domain]['semi'][slot] + else: + return False + if _is_match(value, v): + continue + elif _fuzzy_match(value, v): + continue + else: + return False + return True
+ +def _is_empty(slot, domain_state): + if slot not in domain_state: + return True + value = domain_state[slot] + if value is None or value == "" or value == 'null': + return True + return False + +def _is_match(value1, value2): + if not isinstance(value1, str) or not isinstance(value2, str): + return value1 == value2 + value1 = value1.lower() + value2 = value2.lower() + value1 = ' '.join(value1.strip().split()) + value2 = ' '.join(value2.strip().split()) + if value1 == value2: + return True + return False + +def _fuzzy_match(value1, value2): + if not isinstance(value1, str) or not isinstance(value2, str): + return value1 == value2 + value1 = value1.lower() + value2 = value2.lower() + value1 = ' '.join(value1.strip().split()) + value2 = ' '.join(value2.strip().split()) + d = minDistance(value1, value2) + if (len(value1) >= 10 and d <= 2) or (len(value1) >= 15 and d <= 3): + return True + return False + +if __name__ == '__main__': + test_result = run_test() + json.dump(test_result, open('word_dst_test_result.json', 'w+'), indent=2) + print('test session num: {}'.format(len(test_result))) + evaluate(test_result) +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/word_dst/multiwoz/mdbt.html b/docs/build/html/_modules/convlab/modules/word_dst/multiwoz/mdbt.html
new file mode 100644
index 0000000..693ffe3
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/word_dst/multiwoz/mdbt.html
@@ -0,0 +1,676 @@
+ convlab.modules.word_dst.multiwoz.mdbt — ConvLab 0.1 documentation

Source code for convlab.modules.word_dst.multiwoz.mdbt

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import copy
+import json
+import math
+import os
+import sys
+import time
+from random import shuffle
+
+import numpy as np
+import tensorflow as tf
+
+from convlab.modules.dst.multiwoz.dst_util import init_state, init_belief_state, normalize_value
+from convlab.modules.dst.state_tracker import Tracker
+from convlab.modules.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
+from convlab.modules.word_dst.multiwoz.mdbt_util import model_definition, load_word_vectors, load_ontology, \
+    load_woz_data, \
+    track_dialogue, generate_batch, process_history, evaluate_model
+
+# DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))), 'data/mdbt')
+# VALIDATION_URL = os.path.join(DATA_PATH, "data/validate.json")
+# WORD_VECTORS_URL = os.path.join(DATA_PATH, "word-vectors/paragram_300_sl999.txt")
+# TRAINING_URL = os.path.join(DATA_PATH, "data/train.json")
+# ONTOLOGY_URL = os.path.join(DATA_PATH, "data/ontology.json")
+# TESTING_URL = os.path.join(DATA_PATH, "data/test.json")
+# MODEL_URL = os.path.join(DATA_PATH, "models/model-1")
+# GRAPH_URL = os.path.join(DATA_PATH, "graphs/graph-1")
+# RESULTS_URL = os.path.join(DATA_PATH, "results/log-1.txt")
+# KB_URL = os.path.join(DATA_PATH, "data/")  # TODO: yaoqin
+# TRAIN_MODEL_URL = os.path.join(DATA_PATH, "train_models/model-1")
+# TRAIN_GRAPH_URL = os.path.join(DATA_PATH, "train_graph/graph-1")
+
+train_batch_size = 1
+batches_per_eval = 10
+no_epochs = 600
+device = "gpu"
+start_batch = 0
+
+
+
[docs]class MDBTTracker(Tracker): + """ + A multi-domain belief tracker, adopted from https://github.com/osmanio2/multi-domain-belief-tracking. + """ + def __init__(self, data_dir='data/mdbt'): + Tracker.__init__(self) + # data profile + self.data_dir = data_dir + self.validation_url = os.path.join(self.data_dir, 'data/validate.json') + self.word_vectors_url = os.path.join(self.data_dir, 'word-vectors/paragram_300_sl999.txt') + self.training_url = os.path.join(self.data_dir, 'data/train.json') + self.ontology_url = os.path.join(self.data_dir, 'data/ontology.json') + self.testing_url = os.path.join(self.data_dir, 'data/test.json') + self.model_url = os.path.join(self.data_dir, 'models/model-1') + self.graph_url = os.path.join(self.data_dir, 'graphs/graph-1') + self.results_url = os.path.join(self.data_dir, 'results/log-1.txt') + self.kb_url = os.path.join(self.data_dir, 'data/') # not used + self.train_model_url = os.path.join(self.data_dir, 'train_models/model-1') + self.train_graph_url = os.path.join(self.data_dir, 'train_graph/graph-1') + + print('Configuring MDBT model...') + self.word_vectors = load_word_vectors(self.word_vectors_url) + + # Load the ontology and extract the feature vectors + self.ontology, self.ontology_vectors, self.slots = load_ontology(self.ontology_url, self.word_vectors) + + # Load and process the training data + self.dialogues, self.actual_dialogues = load_woz_data(self.testing_url, self.word_vectors, self.ontology) + self.no_dialogues = len(self.dialogues) + + self.model_variables = model_definition(self.ontology_vectors, len(self.ontology), self.slots, num_hidden=None, + bidir=True, net_type=None, test=True, dev='cpu') + self.state = init_state() + _config = tf.ConfigProto() + _config.gpu_options.allow_growth = True + _config.allow_soft_placement = True + self.sess = tf.Session(config=_config) + self.param_restored = False + self.det_dic = {} + for domain, dic in REF_USR_DA.items(): + for key, value in dic.items(): + assert '-' not in key + self.det_dic[key.lower()] = key + '-' + domain + self.det_dic[value.lower()] = key + '-' + domain + self.value_dict = json.load(open(os.path.join(self.data_dir, '../multiwoz/value_dict.json'))) + +
[docs] def init_session(self): + self.state = init_state() + if not self.param_restored: + self.restore()
+ +
[docs] def restore(self): + self.restore_model(self.sess, tf.train.Saver())
+ +
[docs] def update(self, user_act=None):
        """Update the dialog state."""
        if not isinstance(user_act, str):
            raise Exception('Expected user_act to be <class \'str\'> type, but get {}.'.format(type(user_act)))
        prev_state = self.state
        if not os.path.exists(os.path.join(self.data_dir, "results")):
            os.makedirs(os.path.join(self.data_dir, "results"))

        global train_batch_size

        model_variables = self.model_variables
        (user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels, domain_labels, domain_accuracy,
         slot_accuracy, value_accuracy, value_f1, train_step, keep_prob, predictions,
         true_predictions, [y, _]) = model_variables

        # generate a fake dialogue based on the history (this is to reuse the original MDBT code)
        actual_history = copy.deepcopy(prev_state['history'])  # [[sys, user], [sys, user], ...]
        actual_history[-1].append(user_act)
        actual_history = self.normalize_history(actual_history)
        if len(actual_history) == 0:
            actual_history = [['', user_act if len(user_act) > 0 else 'fake user act']]
        fake_dialogue = {}
        turn_no = 0
        for _sys, _user in actual_history:
            turn = {}
            turn['system'] = _sys
            fake_user = {}
            fake_user['text'] = _user
            fake_user['belief_state'] = init_belief_state
            turn['user'] = fake_user
            key = str(turn_no)
            fake_dialogue[key] = turn
            turn_no += 1
        context, actual_context = process_history([fake_dialogue], self.word_vectors, self.ontology)
        batch_user, batch_sys, batch_labels, batch_domain_labels, batch_user_uttr_len, batch_sys_uttr_len, \
            batch_no_turns = generate_batch(context, 0, 1, len(self.ontology))  # old feature

        # run the model
        [pred, y_pred] = self.sess.run(
            [predictions, y],
            feed_dict={user: batch_user, sys_res: batch_sys,
                       labels: batch_labels,
                       domain_labels: batch_domain_labels,
                       user_uttr_len: batch_user_uttr_len,
                       sys_uttr_len: batch_sys_uttr_len,
                       no_turns: batch_no_turns,
                       keep_prob: 1.0})

        # convert to str output
        dialgs, _, _ = track_dialogue(actual_context, self.ontology, pred, y_pred)
        assert len(dialgs) >= 1
        last_turn = dialgs[0][-1]
        predictions = last_turn['prediction']
        new_belief_state = copy.deepcopy(prev_state['belief_state'])

        # update the belief state
        for item in predictions:
            item = item.lower()
            domain, slot, value = item.strip().split('-')
            value = value[::-1].split(':', 1)[1][::-1]
            if slot == 'price range':
                slot = 'pricerange'
            if slot not in ['name', 'book']:
                if domain not in new_belief_state:
                    raise Exception('Error: domain <{}> not in belief state'.format(domain))
                slot = REF_SYS_DA[domain.capitalize()].get(slot, slot)
                assert 'semi' in new_belief_state[domain]
                assert 'book' in new_belief_state[domain]
                if 'book' in slot:
                    assert slot.startswith('book ')
                    slot = slot.strip().split()[1]
                domain_dic = new_belief_state[domain]
                if slot in domain_dic['semi']:
                    new_belief_state[domain]['semi'][slot] = normalize_value(self.value_dict, domain, slot, value)
                elif slot in domain_dic['book']:
                    new_belief_state[domain]['book'][slot] = value
                elif slot.lower() in domain_dic['book']:
                    new_belief_state[domain]['book'][slot.lower()] = value
                else:
                    with open('mdbt_unknown_slot.log', 'a+') as f:
                        f.write('unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'.format(
                            slot, value, domain, item))
        new_request_state = copy.deepcopy(prev_state['request_state'])
        # update the request state
        user_request_slot = self.detect_requestable_slots(user_act)
        for domain in user_request_slot:
            for key in user_request_slot[domain]:
                if domain not in new_request_state:
                    new_request_state[domain] = {}
                if key not in new_request_state[domain]:
                    new_request_state[domain][key] = user_request_slot[domain][key]
        # update the state
        new_state = copy.deepcopy(dict(prev_state))
        new_state['belief_state'] = new_belief_state
        new_state['request_state'] = new_request_state
        self.state = new_state
        return self.state
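update() expects the raw user utterance as a string and returns the full tracker state. A minimal interaction sketch; paths and utterances are illustrative only:

tracker = MDBTTracker(data_dir='data/mdbt')  # assumes the MDBT data and trained models are on disk
tracker.init_session()                       # also restores the trained parameters
tracker.state['history'].append(['Hello, how can I help you?'])  # system side of the turn
state = tracker.update('I need a cheap hotel in the north.')     # user utterance
print(state['belief_state']['hotel']['semi'])  # e.g. {'area': 'north', 'pricerange': 'cheap', ...}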
+ +
[docs] def normalize_history(self, history): + """Replace zero-length history.""" + for i in range(len(history)): + a, b = history[i] + if len(a) == 0: + history[i][0] = 'sys' + if len(b) == 0: + history[i][1] = 'user' + return history
+ +
[docs] def detect_requestable_slots(self, observation): + result = {} + observation = observation.lower() + _observation = ' {} '.format(observation) + for value in self.det_dic.keys(): + _value = ' {} '.format(value.strip()) + if _value in _observation: + key, domain = self.det_dic[value].split('-') + if domain not in result: + result[domain] = {} + result[domain][key] = 0 + return result
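The detection is a simple surface-form match: every keyword from REF_USR_DA (and its slot name) is looked up, surrounded by spaces, in the lower-cased utterance. A hypothetical example, assuming 'phone' maps to the Hotel domain in REF_USR_DA:

# result = tracker.detect_requestable_slots('can i get the phone of the hotel ?')
# -> {'Hotel': {'Phone': 0}}   (domain/slot casing follows REF_USR_DA)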
+ +
[docs] def restore_model(self, sess, saver): + saver.restore(sess, self.model_url) + print('Loading trained MDBT model from ', self.model_url) + self.param_restored = True
+ +
[docs] def train(self): + """ + Train the model. + Model saved to + """ + num_hid, bidir, net_type, n2p, batch_size, model_url, graph_url, dev = \ + None, True, None, None, None, None, None, None + global train_batch_size, MODEL_URL, GRAPH_URL, device, TRAIN_MODEL_URL, TRAIN_GRAPH_URL + + if batch_size: + train_batch_size = batch_size + print("Setting up the batch size to {}.........................".format(batch_size)) + if model_url: + TRAIN_MODEL_URL = model_url + print("Setting up the model url to {}.........................".format(TRAIN_MODEL_URL)) + if graph_url: + TRAIN_GRAPH_URL = graph_url + print("Setting up the graph url to {}.........................".format(TRAIN_GRAPH_URL)) + + if dev: + device = dev + print("Setting up the device to {}.........................".format(device)) + + # 1 Load and process the input data including the ontology + # Load the word embeddings + word_vectors = load_word_vectors(self.word_vectors_url) + + # Load the ontology and extract the feature vectors + ontology, ontology_vectors, slots = load_ontology(self.ontology_url, word_vectors) + + # Load and process the training data + dialogues, _ = load_woz_data(self.training_url, word_vectors, ontology) + no_dialogues = len(dialogues) + + # Load and process the validation data + val_dialogues, _ = load_woz_data(self.validation_url, word_vectors, ontology) + + # Generate the validation batch data + val_data = generate_batch(val_dialogues, 0, len(val_dialogues), len(ontology)) + val_iterations = int(len(val_dialogues) / train_batch_size) + + # 2 Initialise and set up the model graph + # Initialise the model + graph = tf.Graph() + with graph.as_default(): + model_variables = model_definition(ontology_vectors, len(ontology), slots, num_hidden=num_hid, bidir=bidir, + net_type=net_type, dev=device) + (user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels, domain_labels, domain_accuracy, + slot_accuracy, value_accuracy, value_f1, train_step, keep_prob, _, _, _) = model_variables + [precision, recall, value_f1] = value_f1 + saver = tf.train.Saver() + if device == 'gpu': + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True + else: + config = tf.ConfigProto(device_count={'GPU': 0}) + + sess = tf.Session(config=config) + if os.path.exists(TRAIN_MODEL_URL + ".index"): + saver.restore(sess, TRAIN_MODEL_URL) + print("Loading from an existing model {} ....................".format(TRAIN_MODEL_URL)) + else: + if not os.path.exists(TRAIN_MODEL_URL): + os.makedirs('/'.join(TRAIN_MODEL_URL.split('/')[:-1])) + os.makedirs('/'.join(TRAIN_GRAPH_URL.split('/')[:-1])) + init = tf.global_variables_initializer() + sess.run(init) + print("Create new model parameters.....................................") + merged = tf.summary.merge_all() + val_accuracy = tf.summary.scalar('validation_accuracy', value_accuracy) + val_f1 = tf.summary.scalar('validation_f1_score', value_f1) + train_writer = tf.summary.FileWriter(TRAIN_GRAPH_URL, graph) + train_writer.flush() + + # 3 Perform an epoch of training + last_update = -1 + best_f_score = -1 + for epoch in range(no_epochs): + + batch_size = train_batch_size + sys.stdout.flush() + iterations = math.ceil(no_dialogues / train_batch_size) + start_time = time.time() + val_i = 0 + shuffle(dialogues) + for batch_id in range(iterations): + + if batch_id == iterations - 1 and no_dialogues % iterations != 0: + batch_size = no_dialogues % train_batch_size + + batch_user, batch_sys, batch_labels, batch_domain_labels, batch_user_uttr_len, 
batch_sys_uttr_len, \ + batch_no_turns = generate_batch(dialogues, batch_id, batch_size, len(ontology)) + + [_, summary, da, sa, va, vf, pr, re] = sess.run([train_step, merged, domain_accuracy, slot_accuracy, + value_accuracy, value_f1, precision, recall], + feed_dict={user: batch_user, sys_res: batch_sys, + labels: batch_labels, + domain_labels: batch_domain_labels, + user_uttr_len: batch_user_uttr_len, + sys_uttr_len: batch_sys_uttr_len, + no_turns: batch_no_turns, + keep_prob: 0.5}) + + print("The accuracies for domain is {:.2f}, slot {:.2f}, value {:.2f}, f1_score {:.2f} precision {:.2f}" + " recall {:.2f} for batch {}".format(da, sa, va, vf, pr, re, batch_id + iterations * epoch)) + + train_writer.add_summary(summary, start_batch + batch_id + iterations * epoch) + + # ================================ VALIDATION ============================================== + + if batch_id % batches_per_eval == 0 or batch_id == 0: + if batch_id == 0: + print("Batch", "0", "to", batch_id, "took", round(time.time() - start_time, 2), "seconds.") + + else: + print("Batch", batch_id + iterations * epoch - batches_per_eval, "to", + batch_id + iterations * epoch, "took", + round(time.time() - start_time, 3), "seconds.") + start_time = time.time() + + _, _, v_acc, f1_score, sm1, sm2 = evaluate_model(sess, model_variables, val_data, + [val_accuracy, val_f1], batch_id, val_i) + val_i += 1 + val_i %= val_iterations + train_writer.add_summary(sm1, start_batch + batch_id + iterations * epoch) + train_writer.add_summary(sm2, start_batch + batch_id + iterations * epoch) + stime = time.time() + current_metric = f1_score + print(" Validation metric:", round(current_metric, 5), " eval took", + round(time.time() - stime, 2), "last update at:", last_update, "/", iterations) + + # and if we got a new high score for validation f-score, we need to save the parameters: + if current_metric > best_f_score: + last_update = batch_id + iterations * epoch + 1 + print("\n ====================== New best validation metric:", round(current_metric, 4), + " - saving these parameters. Batch is:", last_update, "/", iterations, + "---------------- =========== \n") + + best_f_score = current_metric + + saver.save(sess, TRAIN_MODEL_URL) + + print("The best parameters achieved a validation metric of", round(best_f_score, 4))
+ +
[docs] def test(self, sess): + """Test the MDBT model on mdbt dataset. Almost the same as original code.""" + if not os.path.exists("../../data/mdbt/results"): + os.makedirs("../../data/mdbt/results") + + global train_batch_size, MODEL_URL, GRAPH_URL + + model_variables = self.model_variables + (user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels, domain_labels, domain_accuracy, + slot_accuracy, value_accuracy, value_f1, train_step, keep_prob, predictions, + true_predictions, [y, _]) = model_variables + [precision, recall, value_f1] = value_f1 + # print("\tMDBT: Loading from an existing model {} ....................".format(MODEL_URL)) + + iterations = math.ceil(self.no_dialogues / train_batch_size) + batch_size = train_batch_size + [slot_acc, tot_accuracy] = [np.zeros(len(self.ontology), dtype="float32"), 0] + slot_accurac = 0 + # value_accurac = np.zeros((len(slots),), dtype="float32") + value_accurac = 0 + joint_accuracy = 0 + f1_score = 0 + preci = 0 + recal = 0 + processed_dialogues = [] + # np.set_printoptions(threshold=np.nan) + for batch_id in range(int(iterations)): + + if batch_id == iterations - 1: + batch_size = self.no_dialogues - batch_id * train_batch_size + + batch_user, batch_sys, batch_labels, batch_domain_labels, batch_user_uttr_len, batch_sys_uttr_len, \ + batch_no_turns = generate_batch(self.dialogues, batch_id, batch_size, len(self.ontology)) + + [da, sa, va, vf, pr, re, pred, true_pred, y_pred] = sess.run( + [domain_accuracy, slot_accuracy, value_accuracy, + value_f1, precision, recall, predictions, + true_predictions, y], + feed_dict={user: batch_user, sys_res: batch_sys, + labels: batch_labels, + domain_labels: batch_domain_labels, + user_uttr_len: batch_user_uttr_len, + sys_uttr_len: batch_sys_uttr_len, + no_turns: batch_no_turns, + keep_prob: 1.0}) + + true = sum([1 if np.array_equal(pred[k, :], true_pred[k, :]) and sum(true_pred[k, :]) > 0 else 0 + for k in range(true_pred.shape[0])]) + actual = sum([1 if sum(true_pred[k, :]) > 0 else 0 for k in range(true_pred.shape[0])]) + ja = true / actual + tot_accuracy += da + # joint_accuracy += ja + slot_accurac += sa + if math.isnan(pr): + pr = 0 + preci += pr + recal += re + if math.isnan(vf): + vf = 0 + f1_score += vf + # value_accurac += va + slot_acc += np.mean(np.asarray(np.equal(pred, true_pred), dtype="float32"), axis=0) + + dialgs, va1, ja = track_dialogue(self.actual_dialogues[batch_id * train_batch_size: + batch_id * train_batch_size + batch_size], + self.ontology, pred, y_pred) + processed_dialogues += dialgs + joint_accuracy += ja + value_accurac += va1 + + print( + "The accuracies for domain is {:.2f}, slot {:.2f}, value {:.2f}, other value {:.2f}, f1_score {:.2f} precision {:.2f}" + " recall {:.2f} for batch {}".format(da, sa, np.mean(va), va1, vf, pr, re, batch_id)) + + print( + "End of evaluating the test set...........................................................................") + + slot_acc /= iterations + # print("The accuracies for each slot:") + # print(value_accurac/iterations) + print("The overall accuracies for domain is" + " {}, slot {}, value {}, f1_score {}, precision {}," + " recall {}, joint accuracy {}".format(tot_accuracy / iterations, slot_accurac / iterations, + value_accurac / iterations, f1_score / iterations, + preci / iterations, recal / iterations, + joint_accuracy / iterations)) + + with open(self.results_url, 'w') as f: + json.dump(processed_dialogues, f, indent=4)
+ + +
[docs]def test_update():
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    _config = tf.ConfigProto()
    _config.gpu_options.allow_growth = True
    _config.allow_soft_placement = True
    start_time = time.time()
    mdbt = MDBTTracker()
    print('\tMDBT: model build time: {:.2f} seconds'.format(time.time() - start_time))
    saver = tf.train.Saver()
    mdbt.restore_model(mdbt.sess, saver)
    # demo state history
    mdbt.state['history'] = [['null', 'I\'m trying to find an expensive restaurant in the centre part of town.'],
                             ['The Cambridge Chop House is a good expensive restaurant in the centre of town. Would you like me to book it for you?',
                              'Yes, a table for 1 at 16:15 on sunday. I need the reference number.']]
    # update() takes only the user utterance; the extra None argument in the
    # original call would raise a TypeError
    new_state = mdbt.update('hi, this is not good')
    print(json.dumps(new_state, indent=4))
    print('all time: {:.2f} seconds'.format(time.time() - start_time))
+ + +
[docs]def evaluate_mdbt_model():
    # Renamed from evaluate_model: a module-level def with that name would shadow
    # the evaluate_model imported from mdbt_util and break MDBTTracker.train().
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    _config = tf.ConfigProto()
    _config.gpu_options.allow_growth = True
    _config.allow_soft_placement = True
    start_time = time.time()
    mdbt = MDBTTracker()
    print('\tMDBT: model build time: {:.2f} seconds'.format(time.time() - start_time))
    saver = tf.train.Saver()
    mdbt.restore_model(mdbt.sess, saver)
    mdbt.test(mdbt.sess)


if __name__ == '__main__':
    evaluate_mdbt_model()
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/word_dst/multiwoz/mdbt_util.html b/docs/build/html/_modules/convlab/modules/word_dst/multiwoz/mdbt_util.html
new file mode 100644
index 0000000..b453025
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/word_dst/multiwoz/mdbt_util.html
@@ -0,0 +1,1199 @@
+convlab.modules.word_dst.multiwoz.mdbt_util — ConvLab 0.1 documentation
Source code for convlab.modules.word_dst.multiwoz.mdbt_util

+# -*- coding: utf-8 -*-
+
+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import json
+import math
+import os
+import time
+from collections import OrderedDict
+from copy import deepcopy
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.client import device_lib
+
+DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))), 'data/mdbt')
+VALIDATION_URL = os.path.join(DATA_PATH, "data/validate.json")
+WORD_VECTORS_URL = os.path.join(DATA_PATH, "word-vectors/paragram_300_sl999.txt")
+TRAINING_URL = os.path.join(DATA_PATH, "data/train.json")
+ONTOLOGY_URL = os.path.join(DATA_PATH, "data/ontology.json")
+TESTING_URL = os.path.join(DATA_PATH, "data/test.json")
+MODEL_URL = os.path.join(DATA_PATH, "models/model-1")
+GRAPH_URL = os.path.join(DATA_PATH, "graphs/graph-1")
+RESULTS_URL = os.path.join(DATA_PATH, "results/log-1.txt")
+
+#ROOT_URL = '../../data/mdbt'
+
+#VALIDATION_URL = "./data/mdbt/data/validate.json"
+#WORD_VECTORS_URL = "./data/mdbt/word-vectors/paragram_300_sl999.txt"
+#TRAINING_URL = "./data/mdbt/data/train.json"
+#ONTOLOGY_URL = "./data/mdbt/data/ontology.json"
+#TESTING_URL = "./data/mdbt/data/test.json"
+#MODEL_URL = "./data/mdbt/models/model-1"
+#GRAPH_URL = "./data/mdbt/graphs/graph-1"
+#RESULTS_URL = "./data/mdbt/results/log-1.txt"
+
+
+domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi']
+
+train_batch_size = 64
+batches_per_eval = 10
+no_epochs = 600
+device = "gpu"
+start_batch = 0
+
+num_slots = 0
+
+booking_slots = {}
+
+network = "lstm"
+bidirect = True
+lstm_num_hidden = 50
+max_utterance_length = 50
+vector_dimension = 300
+max_no_turns = 22
+
+
+# model.py
+
[docs]def get_available_devs(): + local_device_protos = device_lib.list_local_devices() + return [x.name for x in local_device_protos if x.device_type == 'GPU']
+ + +
[docs]class GRU(tf.nn.rnn_cell.RNNCell): + ''' + Create a Gated Recurrent unit to unroll the network through time + for combining the current and previous belief states + ''' + + def __init__(self, W_h, U_h, M_h, W_m, U_m, label_size, reuse=None, binary_output=False): + super(GRU, self).__init__(_reuse=reuse) + self.label_size = label_size + self.M_h = M_h + self.W_m = W_m + self.U_m = U_m + self.U_h = U_h + self.W_h = W_h + self.binary_output = binary_output + + def __call__(self, inputs, state, scope=None): + state_only = tf.slice(state, [0, self.label_size], [-1, -1]) + output_only = tf.slice(state, [0, 0], [-1, self.label_size]) + new_state = tf.tanh(tf.matmul(inputs, self.U_m) + tf.matmul(state_only, self.W_m)) + output = tf.matmul(inputs, self.U_h) + tf.matmul(output_only, self.W_h) + tf.matmul(state_only, self.M_h) + if self.binary_output: + output_ = tf.sigmoid(output) + else: + output_ = tf.nn.softmax(output) + state = tf.concat([output_, new_state], 1) + return output, state + + @property + def state_size(self): + return tf.shape(self.W_m)[0] + self.label_size + + @property + def output_size(self): + return tf.shape(self.W_h)[0]
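
A minimal single-step sketch (not part of the original source) of how this cell folds the current turn's features into the previous belief state; the identity weight matrices and the sizes are made up for illustration. model_definition below drives the same cell through tf.nn.dynamic_rnn.

    import tensorflow as tf

    label_size = 4
    eye = tf.diag(tf.ones(label_size))        # hypothetical weight matrices
    cell = GRU(eye, eye, eye, eye, eye, label_size, binary_output=True)
    inputs_t = tf.zeros([2, label_size])      # batch of 2, one turn of features
    state_t = tf.zeros([2, 2 * label_size])   # [previous output ; memory]
    logits, next_state = cell(inputs_t, state_t)
    # logits: [2, label_size] raw scores; next_state packs sigmoid(logits)
    # together with the new memory, carrying beliefs to the next turn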
+ + +
[docs]def define_CNN_model(utter, num_filters=300, name="r"):
+    """
+    Define a CNN encoder over an utterance: parallel convolutions with filter
+    widths 1, 2 and 3 are max-pooled over time and summed into a single
+    num_filters-dimensional representation per turn.
+    """
+    filter_sizes = [1, 2, 3]
+    W = []
+    b = []
+    for i, filter_size in enumerate(filter_sizes):
+        filter_shape = [filter_size, vector_dimension, 1, num_filters]
+        W.append(tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="F_W"))
+        b.append(tf.Variable(tf.constant(0.1, shape=[num_filters]), name="F_b"))
+
+    utter = tf.reshape(utter, [-1, max_utterance_length, vector_dimension])
+
+    hidden_representation = tf.zeros([num_filters], tf.float32)
+
+    pooled_outputs = []
+    for i, filter_size in enumerate(filter_sizes):
+        # Convolution layer
+        conv = tf.nn.conv2d(
+            tf.expand_dims(utter, -1),
+            W[i],
+            strides=[1, 1, 1, 1],
+            padding="VALID",
+            name="conv_R")
+        # Apply nonlinearity
+        h = tf.nn.relu(tf.nn.bias_add(conv, b[i]), name="relu")
+        # Max-pooling over the outputs
+        pooled = tf.nn.max_pool(
+            h,
+            ksize=[1, max_utterance_length - filter_size + 1, 1, 1],
+            strides=[1, 1, 1, 1],
+            padding='VALID',
+            name="r_")
+        pooled_outputs.append(pooled)
+
+        # Sum the pooled features of each filter size (tf.concat over a single
+        # tensor is the identity, so this matches the original computation)
+        hidden_representation += tf.reshape(pooled, [-1, num_filters])
+
+    hidden_representation = tf.reshape(hidden_representation, [-1, max_no_turns, num_filters], name=name)
+
+    return hidden_representation
+ + +
[docs]def lstm_model(text_input, utterance_length, num_hidden, name, net_type, bidir):
+    '''
+    Define an LSTM model that will run across the user input and system act
+    :param text_input: [batch_size, max_num_turns, max_utterance_size, vector_dimension]
+    :param utterance_length: number of words in every utterance [batch_size, max_num_turns, 1]
+    :param num_hidden: number of hidden units -- int --
+    :param name: the name of the LSTM network
+    :param net_type: type of the network ("lstm", "gru" or "rnn")
+    :param bidir: use a bidirectional network -- bool --
+    :return: output of the final state [batch_size, max_num_turns, num_hidden]
+    '''
+    with tf.variable_scope(name):
+
+        text_input = tf.reshape(text_input, [-1, max_utterance_length, vector_dimension])
+        utterance_length = tf.reshape(utterance_length, [-1])
+
+        def rnn(net_typ, num_units):
+            if net_typ == "lstm":
+                return tf.nn.rnn_cell.LSTMCell(num_units)
+            elif net_typ == "gru":
+                return tf.nn.rnn_cell.GRUCell(num_units)
+            else:
+                return tf.nn.rnn_cell.BasicRNNCell(num_units)
+
+        if bidir:
+            assert num_hidden % 2 == 0
+            rev_cell = rnn(net_type, num_hidden // 2)
+            cell = rnn(net_type, num_hidden // 2)
+            _, lspd = tf.nn.bidirectional_dynamic_rnn(cell, rev_cell, text_input, dtype=tf.float32,
+                                                      sequence_length=utterance_length)
+            if net_type == "lstm":
+                lspd = (lspd[0].h, lspd[1].h)
+
+            last_state = tf.concat(lspd, 1)
+        else:
+            cell = rnn(net_type, num_hidden)
+            _, last_state = tf.nn.dynamic_rnn(cell, text_input, dtype=tf.float32, sequence_length=utterance_length)
+            if net_type == "lstm":
+                last_state = last_state.h
+
+        last_state = tf.reshape(last_state, [-1, max_no_turns, num_hidden])
+
+        return last_state
+ + +
[docs]def model_definition(ontology, num_slots, slots, num_hidden=None, net_type=None, bidir=None, test=False, dev=None):
+    '''
+    Create the neural belief tracker model. It consists of encoding the user and system input, then
+    using the ontology to decode the encodings in a manner that detects whether a domain-slot-value
+    class is mentioned
+    :param ontology: numpy array of the embedded vectors of the ontology [num_slots, 3*vector_dimension]
+    :param num_slots: number of ontology classes --int--
+    :param slots: indices of the values of each slot, as a list of lists of ints
+    :param num_hidden: number of hidden units, i.e. the dimension of the hidden space
+    :param net_type: the type of the encoder network: cnn, lstm, gru, rnn, etc.
+    :param bidir: for recurrent networks, whether the network should be bidirectional
+    :param test: whether this is testing mode (no back-propagation)
+    :param dev: device to run the model on (cpu or gpu)
+    :return: all input variables/placeholders, output metrics (precision, recall, f1-score) and the trainer
+    '''
+    # print('model definition')
+    # print(ontology, num_slots, slots, num_hidden, net_type, bidir, test, dev)
+    global lstm_num_hidden
+
+    if not net_type:
+        net_type = network
+    else:
+        print("\tMDBT: Setting up the type of the network to {}..............................".format(net_type))
+    if bidir is None:
+        bidir = bidirect
+    # else:
+    #     print("\tMDBT: Setting up the recurrent network as bidirectional: {}...........................".format(bidir))
+    if num_hidden:
+        lstm_num_hidden = num_hidden
+        print("\tMDBT: Setting up the dimension of the hidden space to {}.........................".format(num_hidden))
+
+    ontology = tf.constant(ontology, dtype=tf.float32)
+
+    # ----------------------------------- Define the input variables --------------------------------------------------
+    user_input = tf.placeholder(tf.float32, [None, max_no_turns, max_utterance_length, vector_dimension], name="user")
+    system_input = tf.placeholder(tf.float32, [None, max_no_turns, max_utterance_length, vector_dimension], name="sys")
+    num_turns = tf.placeholder(tf.int32, [None], name="num_turns")
+    user_utterance_lengths = tf.placeholder(tf.int32, [None, max_no_turns], name="user_sen_len")
+    sys_utterance_lengths = tf.placeholder(tf.int32, [None, max_no_turns], name="sys_sen_len")
+    labels = tf.placeholder(tf.float32, [None, max_no_turns, num_slots], name="labels")
+    domain_labels = tf.placeholder(tf.float32, [None, max_no_turns, num_slots], name="domain_labels")
+    # dropout placeholder, 0.5 for training, 1.0 for validation/testing:
+    keep_prob = tf.placeholder("float")
+
+    # ------------------------------------ Create the Encoder networks ------------------------------------------------
+    devs = ['/device:CPU:0']
+    if dev == 'gpu':
+        devs = get_available_devs()
+
+    if net_type == "cnn":
+        with tf.device(devs[1 % len(devs)]):
+            # Encode the domain of the user input using a CNN network
+            usr_dom_en = define_CNN_model(user_input, num_filters=lstm_num_hidden, name="h_u_d")
+            # Encode the domain of the system act using a CNN network
+            sys_dom_en = define_CNN_model(system_input, num_filters=lstm_num_hidden, name="h_s_d")
+
+        with tf.device(devs[2 % len(devs)]):
+            # Encode the slot of the user input using a CNN network
+            usr_slot_en = define_CNN_model(user_input, num_filters=lstm_num_hidden, name="h_u_s")
+            # Encode the slot of the system act using a CNN network
+            sys_slot_en = define_CNN_model(system_input, num_filters=lstm_num_hidden, name="h_s_s")
+            # Encode the value of the user input 
using a CNN network + usr_val_en = define_CNN_model(user_input, num_filters=lstm_num_hidden, name="h_u_v") + # Encode the value of the system act using a CNN network + sys_val_en = define_CNN_model(system_input, num_filters=lstm_num_hidden, name="h_s_v") + # Encode the user using a CNN network + usr_en = define_CNN_model(user_input, num_filters=lstm_num_hidden // 5, name="h_u") + + else: + + with tf.device(devs[1 % len(devs)]): + # Encode the domain of the user input using a LSTM network + usr_dom_en = lstm_model(user_input, user_utterance_lengths, lstm_num_hidden, "h_u_d", net_type, bidir) + usr_dom_en = tf.nn.dropout(usr_dom_en, keep_prob, name="h_u_d_out") + # Encode the domain of the system act using a LSTM network + sys_dom_en = lstm_model(system_input, sys_utterance_lengths, lstm_num_hidden, "h_s_d", net_type, bidir) + sys_dom_en = tf.nn.dropout(sys_dom_en, keep_prob, name="h_s_d_out") + + with tf.device(devs[2 % len(devs)]): + # Encode the slot of the user input using a LSTM network + usr_slot_en = lstm_model(user_input, user_utterance_lengths, lstm_num_hidden, "h_u_s", net_type, bidir) + usr_slot_en = tf.nn.dropout(usr_slot_en, keep_prob, name="h_u_s_out") + # Encode the slot of the system act using a LSTM network + sys_slot_en = lstm_model(system_input, sys_utterance_lengths, lstm_num_hidden, "h_s_s", net_type, bidir) + sys_slot_en = tf.nn.dropout(sys_slot_en, keep_prob, name="h_s_s_out") + # Encode the value of the user input using a LSTM network + usr_val_en = lstm_model(user_input, user_utterance_lengths, lstm_num_hidden, "h_u_v", net_type, bidir) + usr_val_en = tf.nn.dropout(usr_val_en, keep_prob, name="h_u_v_out") + # Encode the value of the system act using a LSTM network + sys_val_en = lstm_model(system_input, sys_utterance_lengths, lstm_num_hidden, "h_s_v", net_type, bidir) + sys_val_en = tf.nn.dropout(sys_val_en, keep_prob, name="h_s_v_out") + # Encode the user using a LSTM network + usr_en = lstm_model(user_input, user_utterance_lengths, lstm_num_hidden // 5, "h_u", net_type, bidir) + usr_en = tf.nn.dropout(usr_en, keep_prob, name="h_u_out") + + with tf.device(devs[1 % len(devs)]): + usr_dom_en = tf.tile(tf.expand_dims(usr_dom_en, axis=2), [1, 1, num_slots, 1], name="h_u_d") + sys_dom_en = tf.tile(tf.expand_dims(sys_dom_en, axis=2), [1, 1, num_slots, 1], name="h_s_d") + with tf.device(devs[2 % len(devs)]): + usr_slot_en = tf.tile(tf.expand_dims(usr_slot_en, axis=2), [1, 1, num_slots, 1], name="h_u_s") + sys_slot_en = tf.tile(tf.expand_dims(sys_slot_en, axis=2), [1, 1, num_slots, 1], name="h_s_s") + usr_val_en = tf.tile(tf.expand_dims(usr_val_en, axis=2), [1, 1, num_slots, 1], name="h_u_v") + sys_val_en = tf.tile(tf.expand_dims(sys_val_en, axis=2), [1, 1, num_slots, 1], name="h_s_v") + usr_en = tf.tile(tf.expand_dims(usr_en, axis=2), [1, 1, num_slots, 1], name="h_u") + + # All encoding vectors have size [batch_size, max_turns, num_slots, num_hidden] + + # Matrix that transforms the ontology from the embedding space to the hidden representation + with tf.device(devs[1 % len(devs)]): + W_onto_domain = tf.Variable(tf.random_normal([vector_dimension, lstm_num_hidden]), name="W_onto_domain") + W_onto_slot = tf.Variable(tf.random_normal([vector_dimension, lstm_num_hidden]), name="W_onto_slot") + W_onto_value = tf.Variable(tf.random_normal([vector_dimension, lstm_num_hidden]), name="W_onto_value") + + # And biases + b_onto_domain = tf.Variable(tf.zeros([lstm_num_hidden]), name="b_onto_domain") + b_onto_slot = tf.Variable(tf.zeros([lstm_num_hidden]), name="b_onto_slot") + 
b_onto_value = tf.Variable(tf.zeros([lstm_num_hidden]), name="b_onto_value") + + # Apply the transformation from the embedding space of the ontology to the hidden space + domain_vec = tf.slice(ontology, begin=[0, 0], size=[-1, vector_dimension]) + slot_vec = tf.slice(ontology, begin=[0, vector_dimension], size=[-1, vector_dimension]) + value_vec = tf.slice(ontology, begin=[0, 2 * vector_dimension], size=[-1, vector_dimension]) + # Each [num_slots, vector_dimension] + d = tf.nn.dropout(tf.tanh(tf.matmul(domain_vec, W_onto_domain) + b_onto_domain), keep_prob, name="d") + s = tf.nn.dropout(tf.tanh(tf.matmul(slot_vec, W_onto_slot) + b_onto_slot), keep_prob, name="s") + v = tf.nn.dropout(tf.tanh(tf.matmul(value_vec, W_onto_value) + b_onto_value), keep_prob, name="v") + # Each [num_slots, num_hidden] + + # Apply the comparison mechanism for all the user and system utterances and ontology values + domain_user = tf.multiply(usr_dom_en, d, name="domain_user") + domain_sys = tf.multiply(sys_dom_en, d, name="domain_sys") + slot_user = tf.multiply(usr_slot_en, s, name="slot_user") + slot_sys = tf.multiply(sys_slot_en, s, name="slot_sys") + value_user = tf.multiply(usr_val_en, v, name="value_user") + value_sys = tf.multiply(sys_val_en, v, name="value_sys") + # All of size [batch_size, max_turns, num_slots, num_hidden] + + # -------------- Domain Detection ------------------------------------------------------------------------- + W_domain = tf.Variable(tf.random_normal([2 * lstm_num_hidden]), name="W_domain") + b_domain = tf.Variable(tf.zeros([1]), name="b_domain") + y_d = tf.sigmoid(tf.reduce_sum(tf.multiply(tf.concat([domain_user, domain_sys], axis=3), W_domain), axis=3) + + b_domain) # [batch_size, max_turns, num_slots] + + # -------- Run through each of the 3 case ( inform, request, confirm) and decode the inferred state --------- + # 1 Inform (User is informing the system about the goal, e.g. "I am looking for a place to stay in the centre") + W_inform = tf.Variable(tf.random_normal([2 * lstm_num_hidden]), name="W_inform") + b_inform = tf.Variable(tf.random_normal([1]), name="b_inform") + inform = tf.add(tf.reduce_sum(tf.multiply(tf.concat([slot_user, value_user], axis=3), W_inform), axis=3), b_inform, + name="inform") # [batch_size, max_turns, num_slots] + + # 2 Request (The system is requesting information from the user, e.g. "what type of food would you like?") + with tf.device(devs[2 % len(devs)]): + W_request = tf.Variable(tf.random_normal([2 * lstm_num_hidden]), name="W_request") + b_request = tf.Variable(tf.random_normal([1]), name="b_request") + request = tf.add(tf.reduce_sum(tf.multiply(tf.concat([slot_sys, value_user], axis=3), W_request), axis=3), + b_request, name="request") # [batch_size, max_turns, num_slots] + + # 3 Confirm (The system is confirming values given by the user, e.g. 
"How about turkish food?") + with tf.device(devs[3 % len(devs)]): + size = 2 * lstm_num_hidden + lstm_num_hidden // 5 + W_confirm = tf.Variable(tf.random_normal([size]), name="W_confirm") + b_confirm = tf.Variable(tf.random_normal([1]), name="b_confirm") + confirm = tf.add( + tf.reduce_sum(tf.multiply(tf.concat([slot_sys, value_sys, usr_en], axis=3), W_confirm), axis=3), + b_confirm, name="confirm") # [batch_size, max_turns, num_slots] + + output = inform + request + confirm + + # -------------------- Adding the belief update RNN with memory cell (Taken from previous model) ------------------- + with tf.device(devs[2 % len(devs)]): + domain_memory = tf.Variable(tf.random_normal([1, 1]), name="domain_memory") + domain_current = tf.Variable(tf.random_normal([1, 1]), name="domain_current") + domain_M_h = tf.Variable(tf.random_normal([1, 1]), name="domain_M_h") + domain_W_m = tf.Variable(tf.random_normal([1, 1], name="domain_W_m")) + domain_U_m = tf.Variable(tf.random_normal([1, 1]), name="domain_U_m") + a_memory = tf.Variable(tf.random_normal([1, 1]), name="a_memory") + b_memory = tf.Variable(tf.random_normal([1, 1]), name="b_memory") + a_current = tf.Variable(tf.random_normal([1, 1]), name="a_current") + b_current = tf.Variable(tf.random_normal([1, 1]), name="b_current") + M_h_a = tf.Variable(tf.random_normal([1, 1]), name="M_h_a") + M_h_b = tf.Variable(tf.random_normal([1, 1]), name="M_h_b") + W_m_a = tf.Variable(tf.random_normal([1, 1]), name="W_m_a") + W_m_b = tf.Variable(tf.random_normal([1, 1]), name="W_m_b") + U_m_a = tf.Variable(tf.random_normal([1, 1]), name="U_m_a") + U_m_b = tf.Variable(tf.random_normal([1, 1]), name="U_m_b") + + # ---------------------------------- Unroll the domain over time -------------------------------------------------- + with tf.device(devs[1 % len(devs)]): + cell = GRU(domain_memory * tf.diag(tf.ones(num_slots)), domain_current * tf.diag(tf.ones(num_slots)), + domain_M_h * tf.diag(tf.ones(num_slots)), domain_W_m * tf.diag(tf.ones(num_slots)), + domain_U_m * tf.diag(tf.ones(num_slots)), num_slots, + binary_output=True) + + y_d, _ = tf.nn.dynamic_rnn(cell, y_d, sequence_length=num_turns, dtype=tf.float32) + + domain_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=domain_labels, logits=y_d), axis=2, + name="domain_loss") / (num_slots / len(slots)) + + y_d = tf.sigmoid(y_d) + + with tf.device(devs[0 % len(devs)]): + + loss = [None for _ in range(len(slots))] + slot_pred = [None for _ in range(len(slots))] + slot_label = [None for _ in range(len(slots))] + val_pred = [None for _ in range(len(slots))] + val_label = [None for _ in range(len(slots))] + y = [None for _ in range(len(slots))] + y_pred = [None for _ in range(len(slots))] + for i in range(len(slots)): + + num_values = slots[i] + 1 # For the none case + size = sum(slots[:i + 1]) - slots[i] + if test: + domain_output = tf.slice(tf.round(y_d), begin=[0, 0, size], size=[-1, -1, slots[i]]) + else: + domain_output = tf.slice(domain_labels, begin=[0, 0, size], size=[-1, -1, slots[i]]) + max_val = tf.expand_dims(tf.reduce_max(domain_output, axis=2), axis=2) + # tf.assert_less_equal(max_val, 1.0) + # tf.assert_equal(tf.round(max_val), max_val) + domain_output = tf.concat([tf.zeros(tf.shape(domain_output)), 1 - max_val], axis=2) + + slot_output = tf.slice(output, begin=[0, 0, size], size=[-1, -1, slots[i]]) + slot_output = tf.concat([slot_output, tf.zeros([tf.shape(output)[0], max_no_turns, 1])], axis=2) + + labels_output = tf.slice(labels, begin=[0, 0, size], size=[-1, -1, slots[i]]) + 
max_val = tf.expand_dims(tf.reduce_max(labels_output, axis=2), axis=2) + # tf.assert_less_equal(max_val, 1.0) + # tf.assert_equal(tf.round(max_val), max_val) + slot_label[i] = max_val + # [Batch_size, max_turns, 1] + labels_output = tf.argmax(tf.concat([labels_output, 1 - max_val], axis=2), axis=2) + # [Batch_size, max_turns] + val_label[i] = tf.cast(tf.expand_dims(labels_output, axis=2), dtype="float") + # [Batch_size, max_turns, 1] + + diag_memory = a_memory * tf.diag(tf.ones(num_values)) + non_diag_memory = tf.matrix_set_diag(b_memory * tf.ones([num_values, num_values]), tf.zeros(num_values)) + W_memory = diag_memory + non_diag_memory + + diag_current = a_current * tf.diag(tf.ones(num_values)) + non_diag_current = tf.matrix_set_diag(b_current * tf.ones([num_values, num_values]), tf.zeros(num_values)) + W_current = diag_current + non_diag_current + + diag_M_h = M_h_a * tf.diag(tf.ones(num_values)) + non_diag_M_h = tf.matrix_set_diag(M_h_b * tf.ones([num_values, num_values]), tf.zeros(num_values)) + M_h = diag_M_h + non_diag_M_h + + diag_U_m = U_m_a * tf.diag(tf.ones(num_values)) + non_diag_U_m = tf.matrix_set_diag(U_m_b * tf.ones([num_values, num_values]), tf.zeros(num_values)) + U_m = diag_U_m + non_diag_U_m + + diag_W_m = W_m_a * tf.diag(tf.ones(num_values)) + non_diag_W_m = tf.matrix_set_diag(W_m_b * tf.ones([num_values, num_values]), tf.zeros(num_values)) + W_m = diag_W_m + non_diag_W_m + + cell = GRU(W_memory, W_current, M_h, W_m, U_m, num_values) + y_predict, _ = tf.nn.dynamic_rnn(cell, slot_output, sequence_length=num_turns, dtype=tf.float32) + + y_predict = y_predict + 1000000.0 * domain_output + # [Batch_size, max_turns, num_values] + + y[i] = tf.nn.softmax(y_predict) + val_pred[i] = tf.cast(tf.expand_dims(tf.argmax(y[i], axis=2), axis=2), dtype="float32") + # [Batch_size, max_turns, 1] + y_pred[i] = tf.slice(tf.one_hot(tf.argmax(y[i], axis=2), dtype=tf.float32, depth=num_values), + begin=[0, 0, 0], size=[-1, -1, num_values - 1]) + y[i] = tf.slice(y[i], begin=[0, 0, 0], size=[-1, -1, num_values - 1]) + slot_pred[i] = tf.cast(tf.reduce_max(y_pred[i], axis=2, keep_dims=True), dtype="float32") + # [Batch_size, max_turns, 1] + loss[i] = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels_output, logits=y_predict) + # [Batch_size, max_turns] + + # ---------------- Compute the output and the loss function (cross_entropy) and add to optimizer-------------------- + cross_entropy = tf.add_n(loss, name="cross_entropy") + # Add the error from the domains + cross_entropy = tf.add(cross_entropy, domain_loss, name="total_loss") + + y = tf.concat(y, axis=2, name="y") + + mask = tf.cast(tf.sequence_mask(num_turns, maxlen=max_no_turns), dtype=tf.float32) + mask_extended = tf.tile(tf.expand_dims(mask, axis=2), [1, 1, num_slots]) + cross_entropy = tf.reduce_sum(mask * cross_entropy, axis=1) / tf.cast(num_turns, dtype=tf.float32) + + optimizer = tf.train.AdamOptimizer(0.001) + train_step = optimizer.minimize(cross_entropy, colocate_gradients_with_ops=True) + + # ----------------- Get the precision, recall f1-score and accuracy ----------------------------------------------- + + # Domain accuracy + true_predictions = tf.reshape(domain_labels, [-1, num_slots]) + predictions = tf.reshape(tf.round(y_d) * mask_extended, [-1, num_slots]) + + y_d = tf.reshape(y_d * mask_extended, [-1, num_slots]) + + _, _, _, domain_accuracy = get_metrics(predictions, true_predictions, num_turns, mask_extended, num_slots) + + mask_extended_2 = tf.tile(tf.expand_dims(mask, axis=2), [1, 1, len(slots)]) + + # Slot 
accuracy
+    true_predictions = tf.reshape(tf.concat(slot_label, axis=2), [-1, len(slots)])
+    predictions = tf.reshape(tf.concat(slot_pred, axis=2) * mask_extended_2, [-1, len(slots)])
+
+    _, _, _, slot_accuracy = get_metrics(predictions, true_predictions, num_turns, mask_extended_2, len(slots))
+
+    # Value accuracy
+    if test:
+        value_accuracy = []
+        mask_extended_3 = tf.expand_dims(mask, axis=2)
+        for i in range(len(slots)):
+            true_predictions = tf.reshape(val_label[i] * mask_extended_3, [-1, 1])
+            predictions = tf.reshape(val_pred[i] * mask_extended_3, [-1, 1])
+
+            _, _, _, value_acc = get_metrics(predictions, true_predictions, num_turns, mask_extended_3, 1)
+            value_accuracy.append(value_acc)
+
+        value_accuracy = tf.stack(value_accuracy)
+    else:
+        true_predictions = tf.reshape(tf.concat(val_label, axis=2) * mask_extended_2, [-1, len(slots)])
+        predictions = tf.reshape(tf.concat(val_pred, axis=2) * mask_extended_2, [-1, len(slots)])
+
+        _, _, _, value_accuracy = get_metrics(predictions, true_predictions, num_turns, mask_extended_2, len(slots))
+
+    # Value f1-score
+    true_predictions = tf.reshape(labels, [-1, num_slots])
+    predictions = tf.reshape(tf.concat(y_pred, axis=2) * mask_extended, [-1, num_slots])
+
+    precision, recall, value_f1_score, _ = get_metrics(predictions, true_predictions, num_turns,
+                                                       mask_extended, num_slots)
+
+    y_ = tf.reshape(y, [-1, num_slots])
+
+    # -------------------- Summarise the training statistics to be viewed in TensorBoard ------------------------------
+    tf.summary.scalar("domain_accuracy", domain_accuracy)
+    tf.summary.scalar("slot_accuracy", slot_accuracy)
+    tf.summary.scalar("value_accuracy", value_accuracy)
+    tf.summary.scalar("value_f1_score", value_f1_score)
+    tf.summary.scalar("cross_entropy", tf.reduce_mean(cross_entropy))
+
+    value_f1_score = [precision, recall, value_f1_score]
+
+    return user_input, system_input, num_turns, user_utterance_lengths, sys_utterance_lengths, labels, domain_labels, \
+        domain_accuracy, slot_accuracy, value_accuracy, value_f1_score, train_step, keep_prob, predictions, \
+        true_predictions, [y_, y_d]
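
A sketch (not part of the original source) of how the helpers above are typically wired together, roughly mirroring what the MDBT tracker does at start-up; it assumes the MDBT data has been downloaded to the *_URL paths defined at the top of this module:

    word_vectors = load_word_vectors(WORD_VECTORS_URL)
    ontology, ontology_vectors, slots = load_ontology(ONTOLOGY_URL, word_vectors)
    model_variables = model_definition(ontology_vectors, len(ontology), slots,
                                       num_hidden=None, net_type=None,
                                       bidir=True, test=True, dev=None)
    (user, sys_res, num_turns, user_uttr_len, sys_uttr_len, labels, domain_labels,
     domain_accuracy, slot_accuracy, value_accuracy, value_f1, train_step,
     keep_prob, predictions, true_predictions, [y, y_d]) = model_variables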
+ + +
[docs]def get_metrics(predictions, true_predictions, no_turns, mask, num_slots): + mask = tf.reshape(mask, [-1, num_slots]) + correct_prediction = tf.cast(tf.equal(predictions, true_predictions), "float32") * mask + + num_positives = tf.reduce_sum(true_predictions) + classified_positives = tf.reduce_sum(predictions) + + true_positives = tf.multiply(predictions, true_predictions) + num_true_positives = tf.reduce_sum(true_positives) + + recall = num_true_positives / num_positives + precision = num_true_positives / classified_positives + f_score = (2 * recall * precision) / (recall + precision) + accuracy = tf.reduce_sum(correct_prediction) / (tf.cast(tf.reduce_sum(no_turns), dtype="float32") * num_slots) + + return precision, recall, f_score, accuracy
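
The same precision/recall/F1 arithmetic, checked with plain NumPy on a toy batch (one turn, four slots, mask of all ones); a sketch, not part of the original source:

    import numpy as np

    pred = np.array([[1., 0., 1., 0.]])
    gold = np.array([[1., 1., 0., 0.]])
    tp = (pred * gold).sum()                   # 1 true positive
    precision = tp / pred.sum()                # 0.5
    recall = tp / gold.sum()                   # 0.5
    f_score = 2 * precision * recall / (precision + recall)  # 0.5
    accuracy = (pred == gold).mean()           # 2 of 4 positions agree: 0.5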
+ + + +# main.py +
[docs]def normalise_word_vectors(word_vectors, norm=1.0): + """ + This method normalises the collection of word vectors provided in the word_vectors dictionary. + """ + for word in word_vectors: + word_vectors[word] /= math.sqrt(sum(word_vectors[word]**2) + 1e-6) + word_vectors[word] *= norm + return word_vectors
+ + +
[docs]def xavier_vector(word, D=300): + """ + Returns a D-dimensional vector for the word. + + We hash the word to always get the same vector for the given word. + """ + def hash_string(_s): + return abs(hash(_s)) % (10 ** 8) + seed_value = hash_string(word) + np.random.seed(seed_value) + + neg_value = - math.sqrt(6)/math.sqrt(D) + pos_value = math.sqrt(6)/math.sqrt(D) + + rsample = np.random.uniform(low=neg_value, high=pos_value, size=(D,)) + norm = np.linalg.norm(rsample) + rsample_normed = rsample/norm + + return rsample_normed
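
A quick check (not part of the original source): the hash-based seed makes the vector reproducible for a given word within one interpreter session (Python salts str hashing across processes unless PYTHONHASHSEED is fixed), and the explicit normalisation makes it unit length:

    import numpy as np

    v1 = xavier_vector("restaurant")
    v2 = xavier_vector("restaurant")
    assert np.allclose(v1, v2)                    # same word, same vector
    assert abs(np.linalg.norm(v1) - 1.0) < 1e-6   # normalised to unit length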
+ + +
[docs]def load_ontology(url, word_vectors):
+    '''
+    Load the ontology from a file
+    :param url: path to the ontology
+    :param word_vectors: dictionary of the word embeddings [words, vector_dimension]
+    :return: list([domain-slot-value]), ontology embeddings [no_values, 3*vector_dimension], number of values per slot
+    '''
+    global num_slots
+    # print("\tMDBT: Loading the ontology....................")
+    data = json.load(open(url, mode='r', encoding='utf8'), object_pairs_hook=OrderedDict)
+    slot_values = []
+    ontology = []
+    slots_values = []
+    ontology_vectors = []
+    for slots in data:
+        [domain, slot] = slots.split('-')
+        if domain not in domains or slot == 'name':
+            continue
+        values = data[slots]
+        if "book" in slot:
+            [slot, value] = slot.split(" ")
+            booking_slots[domain+'-'+value] = values
+            values = [value]
+        elif slot == "departure" or slot == "destination":
+            values = ["place"]
+        domain_vec = np.sum(process_text(domain, word_vectors), axis=0)
+        if domain not in word_vectors:
+            word_vectors[domain.replace(" ", "")] = domain_vec
+        slot_vec = np.sum(process_text(slot, word_vectors), axis=0)
+        if domain+'-'+slot not in slots_values:
+            slots_values.append(domain+'-'+slot)
+        if slot not in word_vectors:
+            word_vectors[slot.replace(" ", "")] = slot_vec
+        slot_values.append(len(values))
+        for value in values:
+            ontology.append(domain + '-' + slot + '-' + value)
+            value_vec = np.sum(process_text(value, word_vectors, print_mode=True), axis=0)
+            if value not in word_vectors:
+                word_vectors[value.replace(" ", "")] = value_vec
+            ontology_vectors.append(np.concatenate((domain_vec, slot_vec, value_vec)))
+
+    num_slots = len(slots_values)
+    # print("\tMDBT: We have about {} values".format(len(ontology)))
+    # print("\tMDBT: The Full Ontology is:")
+    # print(ontology)
+    # print("\tMDBT: The slots in this ontology:")
+    # print(slots_values)
+    return ontology, np.asarray(ontology_vectors, dtype='float32'), slot_values
+ + +
[docs]def load_word_vectors(url): + ''' + Load the word embeddings from the url + :param url: to the word vectors + :return: dict of word and vector values + ''' + word_vectors = {} + # print("Loading the word embeddings....................") + # print('abs path: ', os.path.abspath(url)) + with open(url, mode='r', encoding='utf8') as f: + for line in f: + line = line.split(" ", 1) + key = line[0] + word_vectors[key] = np.fromstring(line[1], dtype="float32", sep=" ") + # print("\tMDBT: The vocabulary contains about {} word embeddings".format(len(word_vectors))) + return normalise_word_vectors(word_vectors)
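
Each line of the embeddings file is expected to be "word v1 v2 ... v300"; a tiny sketch (not part of the original source) of the per-line parsing, with a 3-dimensional stand-in vector:

    import numpy as np

    line = "restaurant 0.1 -0.2 0.3".split(" ", 1)
    vec = np.fromstring(line[1], dtype="float32", sep=" ")
    print(line[0], vec.shape)   # restaurant (3,)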
+ + +
[docs]def track_dialogue(data, ontology, predictions, y): + overall_accuracy_total = 0 + overall_accuracy_corr = 0 + joint_accuracy_total = 0 + joint_accuracy_corr = 0 + global num_slots + dialogues = [] + idx = 0 + for dialogue in data: + turn_ids = [] + for key in dialogue.keys(): + if key.isdigit(): + turn_ids.append(int(key)) + elif dialogue[key] and key not in domains: + continue + turn_ids.sort() + turns = [] + previous_terms = [] + for key in turn_ids: + turn = dialogue[str(key)] + user_input = turn['user']['text'] + sys_res = turn['system'] + state = turn['user']['belief_state'] + turn_obj = dict() + turn_obj['user'] = user_input + turn_obj['system'] = sys_res + prediction = predictions[idx, :] + indices = np.argsort(prediction)[:-(int(np.sum(prediction)) + 1):-1] + predicted_terms = [process_booking(ontology[i], user_input, previous_terms) for i in indices] + previous_terms = deepcopy(predicted_terms) + turn_obj['prediction'] = ["{}: {}".format(predicted_terms[x], y[idx, i]) for x, i in enumerate(indices)] + turn_obj['True state'] = [] + idx += 1 + unpredicted_labels = 0 + for domain in state: + if domain not in domains: + continue + slots = state[domain]['semi'] + for slot in slots: + if slot == 'name': + continue + value = slots[slot] + if value != '': + label = domain + '-' + slot + '-' + value + turn_obj['True state'].append(label) + if label in predicted_terms: + predicted_terms.remove(label) + else: + unpredicted_labels += 1 + + turns.append(turn_obj) + overall_accuracy_total += num_slots + overall_accuracy_corr += (num_slots - unpredicted_labels - len(predicted_terms)) + if unpredicted_labels + len(predicted_terms) == 0: + joint_accuracy_corr += 1 + joint_accuracy_total += 1 + + dialogues.append(turns) + return dialogues, overall_accuracy_corr/overall_accuracy_total, joint_accuracy_corr/joint_accuracy_total
+ + +
[docs]def process_booking(ontolog_term, usr_input, previous_terms): + usr_input = usr_input.lower().split() + domain, slot, value = ontolog_term.split('-') + if slot == 'book': + for term in previous_terms: + if domain+'-book '+value in term: + ontolog_term = term + break + else: + if value == 'stay' or value == 'people': + numbers = [int(s) for s in usr_input if s.isdigit()] + if len(numbers) == 1: + ontolog_term = domain + '-' + slot + ' ' + value + '-' + str(numbers[0]) + elif len(numbers) == 2: + vals = {} + if usr_input[usr_input.index(str(numbers[0]))+1] in ['people', 'person']: + vals['people'] = str(numbers[0]) + vals['stay'] = str(numbers[1]) + else: + vals['people'] = str(numbers[1]) + vals['stay'] = str(numbers[0]) + ontolog_term = domain + '-' + slot + ' ' + value + '-' + vals[value] + else: + for val in booking_slots[domain+'-'+value]: + if val in ' '.join(usr_input): + ontolog_term = domain + '-' + slot + ' ' + value + '-' + val + break + return ontolog_term
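
A worked example (not part of the original source): resolving a bare "book stay" ontology term against a user utterance that mentions two numbers; the check on the word following "2" decides which number is the party size and which is the length of stay:

    term = process_booking('hotel-book-stay',
                           'i need it for 2 people and 3 nights', [])
    print(term)   # hotel-book stay-3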
+ + +
[docs]def process_history(sessions, word_vectors, ontology):
+    '''
+    Process a list of dialogue-session histories and extract feature vectors
+    :param sessions: the dialogue histories to process
+    :param word_vectors: word embeddings
+    :param ontology: list of domain-slot-value
+    :return: list(num of turns, user_input vectors, system_response vectors, labels) and the original dialogues
+    '''
+    dialogues = []
+    actual_dialogues = []
+    for dialogue in sessions:
+        turn_ids = []
+        for key in dialogue.keys():
+            if key.isdigit():
+                turn_ids.append(int(key))
+            elif dialogue[key] and key not in domains:
+                continue
+        turn_ids.sort()
+        num_turns = len(turn_ids)
+        user_vecs = []
+        sys_vecs = []
+        turn_labels = []
+        turn_domain_labels = []
+        add = False
+        good = True
+        pre_sys = np.zeros([max_utterance_length, vector_dimension], dtype="float32")
+        for key in turn_ids:
+            turn = dialogue[str(key)]
+            user_v, sys_v, labels, domain_labels = process_turn(turn, word_vectors, ontology)
+            if good and (user_v.shape[0] > max_utterance_length or pre_sys.shape[0] > max_utterance_length):
+                # truncate over-length utterances instead of discarding them
+                if user_v.shape[0] > max_utterance_length:
+                    user_v = user_v[:max_utterance_length]
+                if pre_sys.shape[0] > max_utterance_length:
+                    pre_sys = pre_sys[:max_utterance_length]
+            user_vecs.append(user_v)
+            sys_vecs.append(pre_sys)
+            turn_labels.append(labels)
+            turn_domain_labels.append(domain_labels)
+            if not add and sum(labels) > -1:
+                add = True
+            pre_sys = sys_v
+        if add and good:
+            dialogues.append((num_turns, user_vecs, sys_vecs, turn_labels, turn_domain_labels))
+            actual_dialogues.append(dialogue)
+    # print("\tMDBT: The data contains about {} dialogues".format(len(dialogues)))
+    return dialogues, actual_dialogues
+ + +
[docs]def load_woz_data(data, word_vectors, ontology, url=True): + ''' + Load the woz3 data and extract feature vectors + :param data: the data to load + :param word_vectors: word embeddings + :param ontology: list of domain-slot-value + :param url: Is the data coming from a url, default true + :return: list(num of turns, user_input vectors, system_response vectors, labels) + ''' + if url: + # print("Loading data from url {} ....................".format(data)) + data = json.load(open(data, mode='r', encoding='utf8')) + + dialogues = [] + actual_dialogues = [] + for dialogue in data: + turn_ids = [] + for key in dialogue.keys(): + if key.isdigit(): + turn_ids.append(int(key)) + elif dialogue[key] and key not in domains: + continue + turn_ids.sort() + num_turns = len(turn_ids) + user_vecs = [] + sys_vecs = [] + turn_labels = [] + turn_domain_labels = [] + add = False + good = True + pre_sys = np.zeros([max_utterance_length, vector_dimension], dtype="float32") + for key in turn_ids: + turn = dialogue[str(key)] + user_v, sys_v, labels, domain_labels = process_turn(turn, word_vectors, ontology) + if good and (user_v.shape[0] > max_utterance_length or pre_sys.shape[0] > max_utterance_length): + good = False + break + user_vecs.append(user_v) + sys_vecs.append(pre_sys) + turn_labels.append(labels) + turn_domain_labels.append(domain_labels) + if not add and sum(labels) > 0: + add = True + pre_sys = sys_v + if add and good: + dialogues.append((num_turns, user_vecs, sys_vecs, turn_labels, turn_domain_labels)) + actual_dialogues.append(dialogue) + # print("\tMDBT: The data contains about {} dialogues".format(len(dialogues))) + return dialogues, actual_dialogues
+ + +
[docs]def process_turn(turn, word_vectors, ontology):
+    '''
+    Process a single turn, extracting and processing user text, system response and labels
+    :param turn: dict
+    :param word_vectors: word embeddings
+    :param ontology: list(domain-slot-value)
+    :return: ([utterance length, 300], [utterance length, 300], [no_slots], [no_slots])
+    '''
+    user_input = turn['user']['text']
+    sys_res = turn['system']
+    state = turn['user']['belief_state']
+    user_v = process_text(user_input, word_vectors, ontology)
+    sys_v = process_text(sys_res, word_vectors, ontology)
+    labels = np.zeros(len(ontology), dtype='float32')
+    domain_labels = np.zeros(len(ontology), dtype='float32')
+    for domain in state:
+        if domain not in domains:
+            continue
+        slots = state[domain]['semi']
+        domain_mention = False
+        for slot in slots:
+
+            if slot == 'name':
+                continue
+            value = slots[slot]
+            if "book" in slot:
+                [slot, value] = slot.split(" ")
+            if value != '' and value != 'corsican':
+                if slot == "destination" or slot == "departure":
+                    value = "place"
+                elif value == '09;45':
+                    value = '09:45'
+                elif 'alpha-milton' in value:
+                    value = value.replace('alpha-milton', 'alpha milton')
+                elif value == 'east side':
+                    value = 'east'
+                elif value == ' expensive':
+                    value = 'expensive'
+                labels[ontology.index(domain + '-' + slot + '-' + value)] = 1
+                domain_mention = True
+        if domain_mention:
+            for idx, slot in enumerate(ontology):
+                if domain in slot:
+                    domain_labels[idx] = 1
+
+    return user_v, sys_v, labels, domain_labels
+ + +
[docs]def process_text(text, word_vectors, ontology=None, print_mode=False):
+    '''
+    Process a line/sentence, converting words to feature vectors
+    :param text: sentence
+    :param word_vectors: word embeddings
+    :param ontology: the ontology to do exact matching against
+    :param print_mode: log the cases where the word is not in the pre-trained word vectors
+    :return: [length of sentence, 300]
+    '''
+    text = text.replace("(", "").replace(")", "").replace('"', "").replace(u"’", "'").replace(u"‘", "'")
+    text = text.replace("\t", "").replace("\n", "").replace("\r", "").strip().lower()
+    text = text.replace(',', ' ').replace('.', ' ').replace('?', ' ').replace('-', ' ').replace('/', ' / ')\
+        .replace(':', ' ')
+    if ontology:
+        for entry in ontology:
+            [domain, slot, value] = entry.split('-')
+            # the result must be assigned back, otherwise the matching is a no-op
+            text = text.replace(domain, domain.replace(" ", ""))\
+                .replace(slot, slot.replace(" ", ""))\
+                .replace(value, value.replace(" ", ""))
+
+    words = text.split()
+
+    vectors = []
+    for word in words:
+        word = word.replace("'", "").replace("!", "")
+        if word == "":
+            continue
+        if word not in word_vectors:
+            length = len(word)
+            # back off to a sum of two known sub-words, longest split first
+            for i in range(1, length)[::-1]:
+                if word[:i] in word_vectors and word[i:] in word_vectors:
+                    vec = word_vectors[word[:i]] + word_vectors[word[i:]]
+                    break
+            else:
+                vec = xavier_vector(word)
+                word_vectors[word] = vec
+                if print_mode:
+                    pass
+                    # print("\tMDBT: Adding new word: {}".format(word))
+        else:
+            vec = word_vectors[word]
+        vectors.append(vec)
+    return np.asarray(vectors, dtype='float32')
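
A small sketch (not part of the original source): with an empty embedding table every surviving token falls back to a hashed xavier_vector entry, so the sentence still maps to one 300-dimensional vector per token:

    vecs = process_text("I'm looking for a cheap restaurant.", {})
    print(vecs.shape)   # (6, 300): im / looking / for / a / cheap / restaurant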
+ + +
[docs]def generate_batch(dialogues, batch_no, batch_size, ontology_size): + ''' + Generate examples for minibatch training + :param dialogues: list(num of turns, user_input vectors, system_response vectors, labels) + :param batch_no: where we are in the training data + :param batch_size: number of dialogues to generate + :param ontology_size: no_slots + :return: list(user_input, system_response, labels, user_sentence_length, system_sentence_length, number of turns) + ''' + user = np.zeros((batch_size, max_no_turns, max_utterance_length, vector_dimension), dtype='float32') + sys_res = np.zeros((batch_size, max_no_turns, max_utterance_length, vector_dimension), dtype='float32') + labels = np.zeros((batch_size, max_no_turns, ontology_size), dtype='float32') + domain_labels = np.zeros((batch_size, max_no_turns, ontology_size), dtype='float32') + user_uttr_len = np.zeros((batch_size, max_no_turns), dtype='int32') + sys_uttr_len = np.zeros((batch_size, max_no_turns), dtype='int32') + no_turns = np.zeros(batch_size, dtype='int32') + idx = 0 + for i in range(batch_no*train_batch_size, batch_no*train_batch_size + batch_size): + (num_turns, user_vecs, sys_vecs, turn_labels, turn_domain_labels) = dialogues[i] + no_turns[idx] = num_turns + for j in range(num_turns): + user_uttr_len[idx, j] = user_vecs[j].shape[0] + sys_uttr_len[idx, j] = sys_vecs[j].shape[0] + user[idx, j, :user_uttr_len[idx, j], :] = user_vecs[j] + sys_res[idx, j, :sys_uttr_len[idx, j], :] = sys_vecs[j] + labels[idx, j, :] = turn_labels[j] + domain_labels[idx, j, :] = turn_domain_labels[j] + idx += 1 + return user, sys_res, labels, domain_labels, user_uttr_len, sys_uttr_len, no_turns
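
A shape check (not part of the original source) using the module's constants, with a single fabricated two-turn dialogue of zero vectors and a 10-class ontology; the padded outputs line up with the placeholders created in model_definition:

    import numpy as np

    turns = 2
    fake = (turns,
            [np.zeros((7, vector_dimension), dtype='float32')] * turns,   # user
            [np.zeros((5, vector_dimension), dtype='float32')] * turns,   # system
            [np.zeros(10, dtype='float32')] * turns,                      # labels
            [np.zeros(10, dtype='float32')] * turns)                      # domains
    user, sys_res, labels, dom, u_len, s_len, n = generate_batch(
        [fake], batch_no=0, batch_size=1, ontology_size=10)
    print(user.shape, labels.shape, u_len[0, :turns], n)
    # (1, 22, 50, 300) (1, 22, 10) [7 7] [2]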
+ + +
[docs]def evaluate_model(sess, model_variables, val_data, summary, batch_id, i): + + ''' + Evaluate the model against validation set + :param sess: training session + :param model_variables: all model input variables + :param val_data: validation data + :param summary: For tensorboard + :param batch_id: where we are in the training data + :param i: the index of the validation data to load + :return: evaluation accuracy and the summary + ''' + + (user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels, domain_labels, domain_accuracy, + slot_accuracy, value_accuracy, value_f1, train_step, keep_prob, _, _, _) = model_variables + + batch_user, batch_sys, batch_labels, batch_domain_labels, batch_user_uttr_len, batch_sys_uttr_len, \ + batch_no_turns = val_data + + start_time = time.time() + + b_z = train_batch_size + [precision, recall, value_f1] = value_f1 + [d_acc, s_acc, v_acc, f1_score, pr, re, sm1, sm2] = sess.run([domain_accuracy, slot_accuracy, value_accuracy, + value_f1, precision, recall] + summary, + feed_dict={user: batch_user[i:i+b_z, :, :, :], + sys_res: batch_sys[i:i+b_z, :, :, :], + labels: batch_labels[i:i+b_z, :, :], + domain_labels: batch_domain_labels[i:i+b_z, :, :], + user_uttr_len: batch_user_uttr_len[i:i+b_z, :], + sys_uttr_len: batch_sys_uttr_len[i:i+b_z, :], + no_turns: batch_no_turns[i:i+b_z], + keep_prob: 1.0}) + + print("Batch", batch_id, "[Domain Accuracy] = ", d_acc, "[Slot Accuracy] = ", s_acc, "[Value Accuracy] = ", + v_acc, "[F1 Score] = ", f1_score, "[Precision] = ", pr, "[Recall] = ", re, + " ----- ", round(time.time() - start_time, 3), + "seconds. ---") + + return d_acc, s_acc, v_acc, f1_score, sm1, sm2
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/model/model.html b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/model/model.html
new file mode 100644
index 0000000..274709b
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/model/model.html
@@ -0,0 +1,776 @@
+convlab.modules.word_policy.multiwoz.mdrg.model.model — ConvLab 0.1 documentation
Source code for convlab.modules.word_policy.multiwoz.mdrg.model.model

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import division, print_function, unicode_literals
+
+import json
+import math
+import operator
+import os
+import random
+from functools import reduce
+from io import open
+# from Queue import PriorityQueue
+from queue import PriorityQueue
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import optim
+
+from convlab.modules.word_policy.multiwoz.mdrg.model import policy
+
+SOS_token = 0
+EOS_token = 1
+UNK_token = 2
+PAD_token = 3
+
+
+# Shawn beam search decoding
+
[docs]class BeamSearchNode(object): + def __init__(self, h, prevNode, wordid, logp, leng): + self.h = h + self.prevNode = prevNode + self.wordid = wordid + self.logp = logp + self.leng = leng + +
[docs] def eval(self, repeatPenalty, tokenReward, scoreTable, alpha=1.0): + reward = 0 + alpha = 1.0 + + return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward
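
eval() length-normalises the accumulated log-probability so longer hypotheses are not penalised merely for having more factors; a quick sketch, not part of the original source:

    n = BeamSearchNode(h=None, prevNode=None, wordid=None, logp=-4.0, leng=5)
    print(n.eval(None, None, None))   # -4.0 / (5 - 1 + 1e-6), roughly -1.0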
+ + +
[docs]def init_lstm(cell, gain=1): + init_gru(cell, gain) + + # positive forget gate bias (Jozefowicz et al., 2015) + for _, _, ih_b, hh_b in cell.all_weights: + l = len(ih_b) + ih_b[l // 4:l // 2].data.fill_(1.0) + hh_b[l // 4:l // 2].data.fill_(1.0)
+ + +
[docs]def init_gru(gru, gain=1): + gru.reset_parameters() + for _, hh, _, _ in gru.all_weights: + for i in range(0, hh.size(0), gru.hidden_size): + torch.nn.init.orthogonal_(hh[i:i+gru.hidden_size],gain=gain)
+ + +
[docs]def whatCellType(input_size, hidden_size, cell_type, dropout_rate): + if cell_type == 'rnn': + cell = nn.RNN(input_size, hidden_size, dropout=dropout_rate, batch_first=False) + init_gru(cell) + return cell + elif cell_type == 'gru': + cell = nn.GRU(input_size, hidden_size, dropout=dropout_rate, batch_first=False) + init_gru(cell) + return cell + elif cell_type == 'lstm': + cell = nn.LSTM(input_size, hidden_size, dropout=dropout_rate, batch_first=False) + init_lstm(cell) + return cell + elif cell_type == 'bigru': + cell = nn.GRU(input_size, hidden_size, bidirectional=True, dropout=dropout_rate, batch_first=False) + init_gru(cell) + return cell + elif cell_type == 'bilstm': + cell = nn.LSTM(input_size, hidden_size, bidirectional=True, dropout=dropout_rate, batch_first=False) + init_lstm(cell) + return cell
+ + +
[docs]class EncoderRNN(nn.Module): + def __init__(self, input_size, embedding_size, hidden_size, cell_type, depth, dropout): + super(EncoderRNN, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.embed_size = embedding_size + self.n_layers = depth + self.dropout = dropout + self.bidirectional = False + if 'bi' in cell_type: + self.bidirectional = True + padding_idx = 3 + self.embedding = nn.Embedding(input_size, embedding_size, padding_idx=padding_idx) + self.rnn = whatCellType(embedding_size, hidden_size, + cell_type, dropout_rate=self.dropout) + +
[docs] def forward(self, input_seqs, input_lens, hidden=None): + """ + forward procedure. **No need for inputs to be sorted** + :param input_seqs: Variable of [T,B] + :param hidden: + :param input_lens: *numpy array* of len for each input sequence + :return: + """ + input_lens = np.asarray(input_lens) + input_seqs = input_seqs.transpose(0,1) + #batch_size = input_seqs.size(1) + embedded = self.embedding(input_seqs) + embedded = embedded.transpose(0, 1) # [B,T,E] + sort_idx = np.argsort(-input_lens) + unsort_idx = torch.LongTensor(np.argsort(sort_idx)) + input_lens = input_lens[sort_idx] + sort_idx = torch.LongTensor(sort_idx) + embedded = embedded[sort_idx].transpose(0, 1) # [T,B,E] + packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lens) + outputs, hidden = self.rnn(packed, hidden) + outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) + if self.bidirectional: + outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] + + outputs = outputs.transpose(0, 1)[unsort_idx].transpose(0, 1).contiguous() + + if isinstance(hidden, tuple): + hidden = list(hidden) + hidden[0] = hidden[0].transpose(0, 1)[unsort_idx].transpose(0, 1).contiguous() + hidden[1] = hidden[1].transpose(0, 1)[unsort_idx].transpose(0, 1).contiguous() + hidden = tuple(hidden) + else: + hidden = hidden.transpose(0, 1)[unsort_idx].transpose(0, 1).contiguous() + + return outputs, hidden
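
The sort/unsort bookkeeping above exists because pack_padded_sequence expects length-sorted batches (the enforce_sorted=False escape hatch arrived in later PyTorch releases); a standalone sketch, not part of the original source, showing that the two argsorts round-trip the batch order:

    import numpy as np
    import torch

    lens = np.asarray([3, 5, 4])
    sort_idx = np.argsort(-lens)                         # [1, 2, 0]
    unsort_idx = torch.LongTensor(np.argsort(sort_idx))  # [2, 0, 1]
    batch = torch.arange(3)
    restored = batch[torch.LongTensor(sort_idx)][unsort_idx]
    assert torch.equal(restored, batch)                  # order round-trips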
+ + +
[docs]class Attn(nn.Module): + def __init__(self, method, hidden_size): + super(Attn, self).__init__() + self.method = method + self.hidden_size = hidden_size + self.attn = nn.Linear(self.hidden_size * 2, hidden_size) + self.v = nn.Parameter(torch.rand(hidden_size)) + stdv = 1. / math.sqrt(self.v.size(0)) + self.v.data.normal_(mean=0, std=stdv) + +
[docs] def forward(self, hidden, encoder_outputs): + ''' + :param hidden: + previous hidden state of the decoder, in shape (layers*directions,B,H) + :param encoder_outputs: + encoder outputs from Encoder, in shape (T,B,H) + :return + attention energies in shape (B,T) + ''' + max_len = encoder_outputs.size(0) + + H = hidden.repeat(max_len,1,1).transpose(0,1) + encoder_outputs = encoder_outputs.transpose(0,1) # [T,B,H] -> [B,T,H] + attn_energies = self.score(H,encoder_outputs) # compute attention score + return F.softmax(attn_energies, dim=1).unsqueeze(1) # normalize with softmax
+ +
[docs] def score(self, hidden, encoder_outputs): + cat = torch.cat([hidden, encoder_outputs], 2) + energy = torch.tanh(self.attn(cat)) # [B*T*2H]->[B*T*H] + energy = energy.transpose(2,1) # [B*H*T] + v = self.v.repeat(encoder_outputs.data.shape[0],1).unsqueeze(1) #[B*1*H] + energy = torch.bmm(v,energy) # [B*1*T] + return energy.squeeze(1) # [B*T]
+ + +
[docs]class SeqAttnDecoderRNN(nn.Module):
+    def __init__(self, embedding_size, hidden_size, output_size, cell_type, dropout_p=0.1, max_length=30):
+        super(SeqAttnDecoderRNN, self).__init__()
+        # Define parameters
+        self.hidden_size = hidden_size
+        self.embed_size = embedding_size
+        self.output_size = output_size
+        self.n_layers = 1
+        self.dropout_p = dropout_p
+
+        # Define layers
+        self.embedding = nn.Embedding(output_size, embedding_size)
+        self.dropout = nn.Dropout(dropout_p)
+
+        if 'bi' in cell_type:  # we don't need bidirectionality in decoding
+            cell_type = cell_type.strip('bi')
+        self.rnn = whatCellType(embedding_size + hidden_size, hidden_size, cell_type, dropout_rate=self.dropout_p)
+        self.out = nn.Linear(hidden_size, output_size)
+
+        self.score = nn.Linear(self.hidden_size + self.hidden_size, self.hidden_size)
+        self.attn_combine = nn.Linear(embedding_size + hidden_size, embedding_size)
+
+        # attention
+        self.method = 'concat'
+        self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
+        self.v = nn.Parameter(torch.rand(hidden_size))
+        stdv = 1. / math.sqrt(self.v.size(0))
+        self.v.data.normal_(mean=0, std=stdv)
[docs] def forward(self, input, hidden, encoder_outputs): + if isinstance(hidden, tuple): + h_t = hidden[0] + else: + h_t = hidden + encoder_outputs = encoder_outputs.transpose(0, 1) + embedded = self.embedding(input) # .view(1, 1, -1) + # embedded = F.dropout(embedded, self.dropout_p) + + # SCORE 3 + max_len = encoder_outputs.size(1) + h_t = h_t.transpose(0, 1) # [1,B,D] -> [B,1,D] + h_t = h_t.repeat(1, max_len, 1) # [B,1,D] -> [B,T,D] + energy = self.attn(torch.cat((h_t, encoder_outputs), 2)) # [B,T,2D] -> [B,T,D] + energy = torch.tanh(energy) + energy = energy.transpose(2, 1) # [B,H,T] + v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1) # [B,1,H] + energy = torch.bmm(v, energy) # [B,1,T] + attn_weights = F.softmax(energy, dim=2) # [B,1,T] + + # getting context + context = torch.bmm(attn_weights, encoder_outputs) # [B,1,H] + + # context = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0)) #[B,1,H] + # Combine embedded input word and attended context, run through RNN + rnn_input = torch.cat((embedded, context), 2) + rnn_input = rnn_input.transpose(0, 1) + output, hidden = self.rnn(rnn_input, hidden) + output = output.squeeze(0) # (1,B,V)->(B,V) + + output = F.log_softmax(self.out(output), dim=1) + return output, hidden # , attn_weights
+ + +
[docs]class DecoderRNN(nn.Module):
+    def __init__(self, embedding_size, hidden_size, output_size, cell_type, dropout=0.1):
+        super(DecoderRNN, self).__init__()
+        self.hidden_size = hidden_size
+        self.cell_type = cell_type
+        padding_idx = 3
+        self.embedding = nn.Embedding(num_embeddings=output_size,
+                                      embedding_dim=embedding_size,
+                                      padding_idx=padding_idx)
+        if 'bi' in cell_type:  # we don't need bidirectionality in decoding
+            cell_type = cell_type.strip('bi')
+        self.rnn = whatCellType(embedding_size, hidden_size, cell_type, dropout_rate=dropout)
+        self.dropout_rate = dropout
+        self.out = nn.Linear(hidden_size, output_size)
[docs] def forward(self, input, hidden, not_used): + embedded = self.embedding(input).transpose(0, 1) # [B,1] -> [ 1,B, D] + embedded = F.dropout(embedded, self.dropout_rate) + + output = embedded + #output = F.relu(embedded) + + output, hidden = self.rnn(output, hidden) + + out = self.out(output.squeeze(0)) + output = F.log_softmax(out, dim=1) + + return output, hidden
+ + +
[docs]class Model(nn.Module): + def __init__(self, args, input_lang_index2word, output_lang_index2word, input_lang_word2index, output_lang_word2index): + super(Model, self).__init__() + self.args = args + self.max_len = args.max_len + + self.output_lang_index2word = output_lang_index2word + self.input_lang_index2word = input_lang_index2word + + self.output_lang_word2index = output_lang_word2index + self.input_lang_word2index = input_lang_word2index + + self.hid_size_enc = args.hid_size_enc + self.hid_size_dec = args.hid_size_dec + self.hid_size_pol = args.hid_size_pol + + self.emb_size = args.emb_size + self.db_size = args.db_size + self.bs_size = args.bs_size + self.cell_type = args.cell_type + if 'bi' in self.cell_type: + self.num_directions = 2 + else: + self.num_directions = 1 + self.depth = args.depth + self.use_attn = args.use_attn + self.attn_type = args.attention_type + + self.dropout = args.dropout + self.device = torch.device("cuda" if args.cuda else "cpu") + + self.model_dir = args.model_dir + self.model_name = args.model_name + self.teacher_forcing_ratio = args.teacher_ratio + self.vocab_size = args.vocab_size + self.epsln = 10E-5 + + + torch.manual_seed(args.seed) + self.build_model() + self.getCount() + try: + assert self.args.beam_width > 0 + self.beam_search = True + except: + self.beam_search = False + + self.global_step = 0 + +
[docs] def cuda_(self, var): + return var.cuda() if self.args.cuda else var
+ +
[docs] def build_model(self): + self.encoder = EncoderRNN(len(self.input_lang_index2word), self.emb_size, self.hid_size_enc, + self.cell_type, self.depth, self.dropout).to(self.device) + + self.policy = policy.DefaultPolicy(self.hid_size_pol, self.hid_size_enc, self.db_size, self.bs_size).to(self.device) + + if self.use_attn: + if self.attn_type == 'bahdanau': + self.decoder = SeqAttnDecoderRNN(self.emb_size, self.hid_size_dec, len(self.output_lang_index2word), self.cell_type, self.dropout, self.max_len).to(self.device) + else: + self.decoder = DecoderRNN(self.emb_size, self.hid_size_dec, len(self.output_lang_index2word), self.cell_type, self.dropout).to(self.device) + + if self.args.mode == 'train': + self.gen_criterion = nn.NLLLoss(ignore_index=3, size_average=True) # logsoftmax is done in decoder part + self.setOptimizers()
+ +
[docs] def train(self, input_tensor, input_lengths, target_tensor, target_lengths, db_tensor, bs_tensor, dial_name=None): + proba, _, decoded_sent = self.forward(input_tensor, input_lengths, target_tensor, target_lengths, db_tensor, bs_tensor) + + proba = proba.view(-1, self.vocab_size) + self.gen_loss = self.gen_criterion(proba, target_tensor.view(-1)) + + self.loss = self.gen_loss + self.loss.backward() + grad = self.clipGradients() + self.optimizer.step() + self.optimizer.zero_grad() + + #self.printGrad() + return self.loss.item(), 0, grad
+ +
[docs] def setOptimizers(self): + self.optimizer_policy = None + if self.args.optim == 'sgd': + self.optimizer = optim.SGD(lr=self.args.lr_rate, params=filter(lambda x: x.requires_grad, self.parameters()), weight_decay=self.args.l2_norm) + elif self.args.optim == 'adadelta': + self.optimizer = optim.Adadelta(lr=self.args.lr_rate, params=filter(lambda x: x.requires_grad, self.parameters()), weight_decay=self.args.l2_norm) + elif self.args.optim == 'adam': + self.optimizer = optim.Adam(lr=self.args.lr_rate, params=filter(lambda x: x.requires_grad, self.parameters()), weight_decay=self.args.l2_norm)
+ +
[docs] def forward(self, input_tensor, input_lengths, target_tensor, target_lengths, db_tensor, bs_tensor):
+        """Given the user sentence, user belief state and database pointer,
+        encode the sentence, decide what policy vector to construct and
+        feed it as the first hidden state to the decoder."""
+        target_length = target_tensor.size(1)
+
+        # for fixed encoding this is zero so it does not contribute
+        batch_size, seq_len = input_tensor.size()
+
+        # ENCODER
+        encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths)
+
+        # POLICY
+        decoder_hidden = self.policy(encoder_hidden, db_tensor, bs_tensor)
+
+        # GENERATOR
+        # Teacher forcing: feed the target as the next input
+        _, target_len = target_tensor.size()
+        decoder_input = torch.LongTensor([[SOS_token] for _ in range(batch_size)], device=self.device)
+
+        proba = torch.zeros(batch_size, target_length, self.vocab_size)  # [B,T,V]
+
+        for t in range(target_len):
+            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
+
+            use_teacher_forcing = True if random.random() < self.args.teacher_ratio else False
+            if use_teacher_forcing:
+                decoder_input = target_tensor[:, t].view(-1, 1)  # [B,1] teacher forcing
+            else:
+                # Without teacher forcing: use its own predictions as the next input
+                topv, topi = decoder_output.topk(1)
+                decoder_input = topi.squeeze().detach()  # detach from history as input
+
+            proba[:, t, :] = decoder_output
+
+        decoded_sent = None
+
+        return proba, None, decoded_sent
+ +
[docs] def predict(self, input_tensor, input_lengths, target_tensor, target_lengths, db_tensor, bs_tensor): + with torch.no_grad(): + # ENCODER + encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths) + + # POLICY + decoder_hidden = self.policy(encoder_hidden, db_tensor, bs_tensor) + + # GENERATION + decoded_words = self.decode(target_tensor, decoder_hidden, encoder_outputs) + + return decoded_words, 0
+ +
[docs]    def decode(self, target_tensor, decoder_hidden, encoder_outputs):
+        decoder_hiddens = decoder_hidden
+
+        if self.beam_search:  # wenqiang style - sequicity
+            decoded_sentences = []
+            for idx in range(target_tensor.size(0)):
+                if isinstance(decoder_hiddens, tuple):  # LSTM case
+                    decoder_hidden = (decoder_hiddens[0][:, idx, :].unsqueeze(0), decoder_hiddens[1][:, idx, :].unsqueeze(0))
+                else:
+                    decoder_hidden = decoder_hiddens[:, idx, :].unsqueeze(0)
+                encoder_output = encoder_outputs[:, idx, :].unsqueeze(1)
+
+                # Beam start
+                self.topk = 1
+                endnodes = []  # stores finished hypotheses
+                number_required = min((self.topk + 1), self.topk - len(endnodes))
+                # the legacy torch.LongTensor constructor takes no device argument
+                decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=self.device)
+
+                # starting node: hidden vector, prevNode, wordid, logp, leng
+                node = BeamSearchNode(decoder_hidden, None, decoder_input, 0, 1)
+                nodes = PriorityQueue()  # start the queue
+                nodes.put((-node.eval(None, None, None, None), node))
+
+                # start beam search
+                qsize = 1
+                while True:
+                    # give up when decoding takes too long
+                    if qsize > 2000:
+                        break
+
+                    # fetch the best node (lowest stored score, i.e. highest log-probability)
+                    score, n = nodes.get()
+                    decoder_input = n.wordid
+                    decoder_hidden = n.h
+
+                    if n.wordid.item() == EOS_token and n.prevNode is not None:  # non-empty hypothesis ended
+                        endnodes.append((score, n))
+                        # stop once the required number of sentences is reached
+                        if len(endnodes) >= number_required:
+                            break
+                        else:
+                            continue
+
+                    # decode one step using the decoder
+                    decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_output)
+
+                    log_prob, indexes = torch.topk(decoder_output, self.args.beam_width)
+                    nextnodes = []
+
+                    for new_k in range(self.args.beam_width):
+                        decoded_t = indexes[0][new_k].view(1, -1)
+                        log_p = log_prob[0][new_k].item()
+
+                        node = BeamSearchNode(decoder_hidden, n, decoded_t, n.logp + log_p, n.leng + 1)
+                        score = -node.eval(None, None, None, None)
+                        nextnodes.append((score, node))
+
+                    # put them into the queue (next_node avoids shadowing torch.nn)
+                    for i in range(len(nextnodes)):
+                        score, next_node = nextnodes[i]
+                        nodes.put((score, next_node))
+
+                    # increase qsize
+                    qsize += len(nextnodes)
+
+                # choose the n-best paths and back-trace them
+                if len(endnodes) == 0:
+                    endnodes = [nodes.get() for _ in range(self.topk)]
+
+                utterances = []
+                for score, n in sorted(endnodes, key=operator.itemgetter(0)):
+                    utterance = [n.wordid]
+                    # back trace
+                    while n.prevNode is not None:
+                        n = n.prevNode
+                        utterance.append(n.wordid)
+
+                    utterance = utterance[::-1]
+                    utterances.append(utterance)
+
+                decoded_words = utterances[0]
+                decoded_sentence = [self.output_index2word(str(ind.item())) for ind in decoded_words]
+                decoded_sentences.append(' '.join(decoded_sentence[1:-1]))
+
+            return decoded_sentences
+
+        else:  # GREEDY DECODING
+            return self.greedy_decode(decoder_hidden, encoder_outputs, target_tensor)
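The negation before nodes.put(...) is what turns the queue into a max-heap over log-probabilities: queue.PriorityQueue always pops the smallest item first. A tiny illustration:

from queue import PriorityQueue

q = PriorityQueue()
for logp, word in [(-0.1, 'a'), (-2.3, 'b'), (-0.7, 'c')]:
    q.put((-logp, word))  # store the negated log-probability
print(q.get())            # (0.1, 'a'): the most probable candidate pops first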
+ +
[docs]    def greedy_decode(self, decoder_hidden, encoder_outputs, target_tensor):
+        decoded_sentences = []
+        batch_size, seq_len = target_tensor.size()
+        # the legacy torch.LongTensor constructor takes no device argument
+        decoder_input = torch.tensor([[SOS_token]] * batch_size, dtype=torch.long, device=self.device)
+
+        decoded_words = torch.zeros((batch_size, self.max_len))
+        for t in range(self.max_len):
+            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
+
+            topv, topi = decoder_output.data.topk(1)  # greedy: take the argmax candidate
+            topi = topi.view(-1)
+
+            decoded_words[:, t] = topi
+            decoder_input = topi.detach().view(-1, 1)
+
+        for sentence in decoded_words:
+            sent = []
+            for ind in sentence:
+                word = self.output_index2word(str(int(ind.item())))
+                if word == self.output_index2word(str(EOS_token)):
+                    break
+                sent.append(word)
+            decoded_sentences.append(' '.join(sent))
+
+        return decoded_sentences
+ +
[docs] def clipGradients(self): + grad = torch.nn.utils.clip_grad_norm_(self.parameters(), self.args.clip) + return grad
+ +
[docs] def saveModel(self, iter): + print('Saving parameters..') + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + + torch.save(self.encoder.state_dict(), self.model_dir + self.model_name + '-' + str(iter) + '.enc') + torch.save(self.policy.state_dict(), self.model_dir + self.model_name + '-' + str(iter) + '.pol') + torch.save(self.decoder.state_dict(), self.model_dir + self.model_name + '-' + str(iter) + '.dec') + + with open(self.model_dir + self.model_name + '.config', 'w') as f: + # f.write(unicode(json.dumps(vars(self.args), ensure_ascii=False, indent=4))) + f.write(json.dumps(vars(self.args), ensure_ascii=False, indent=4))
+ +
[docs] def loadModel(self, model_file=None, iter=0): + print('Loading parameters of iter %s ' % iter) + # self.encoder.load_state_dict(torch.load(self.model_dir + self.model_name + '-' + str(iter) + '.enc')) + # self.policy.load_state_dict(torch.load(self.model_dir + self.model_name + '-' + str(iter) + '.pol')) + # self.decoder.load_state_dict(torch.load(self.model_dir + self.model_name + '-' + str(iter) + '.dec')) + self.encoder.load_state_dict(torch.load(model_file + '.enc')) + self.policy.load_state_dict(torch.load(model_file + '.pol')) + self.decoder.load_state_dict(torch.load(model_file + '.dec'))
+ +
[docs] def input_index2word(self, index): + # if self.input_lang_index2word.has_key(index): + if index in self.input_lang_index2word: + return self.input_lang_index2word[index] + else: + raise UserWarning('We are using UNK')
+ +
[docs] def output_index2word(self, index): + # if self.output_lang_index2word.has_key(index): + if index in self.output_lang_index2word: + return self.output_lang_index2word[index] + else: + raise UserWarning('We are using UNK')
+ +
[docs] def input_word2index(self, index): + # if self.input_lang_word2index.has_key(index): + if index in self.input_lang_word2index: + return self.input_lang_word2index[index] + else: + return 2
+ +
[docs] def output_word2index(self, index): + # if self.output_lang_word2index.has_key(index): + if index in self.output_lang_word2index: + return self.output_lang_word2index[index] + else: + return 2
+ +
[docs] def getCount(self): + learnable_parameters = filter(lambda p: p.requires_grad, self.parameters()) + param_cnt = sum([reduce((lambda x, y: x * y), param.shape) for param in learnable_parameters]) + print('Model has', param_cnt, ' parameters.')
+ +
[docs] def printGrad(self): + learnable_parameters = filter(lambda p: p.requires_grad, self.parameters()) + for idx, param in enumerate(learnable_parameters): + print(param.grad, param.shape)
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/model/policy.html b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/model/policy.html
new file mode 100644
index 0000000..24c2c12
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/model/policy.html
@@ -0,0 +1,217 @@
+convlab.modules.word_policy.multiwoz.mdrg.model.policy — ConvLab 0.1 documentation

Source code for convlab.modules.word_policy.multiwoz.mdrg.model.policy

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+
+
+
[docs]class DefaultPolicy(nn.Module): + def __init__(self, hidden_size_pol, hidden_size, db_size, bs_size): + super(DefaultPolicy, self).__init__() + self.hidden_size = hidden_size + + + self.W_u = nn.Linear(hidden_size, hidden_size_pol, bias=False) + self.W_bs = nn.Linear(bs_size, hidden_size_pol, bias=False) + self.W_db = nn.Linear(db_size, hidden_size_pol, bias=False) + +
[docs] def forward(self, encodings, db_tensor, bs_tensor, act_tensor=None): + if isinstance(encodings, tuple): + hidden = encodings[0] + else: + hidden = encodings + + # Network based + output = self.W_u(hidden[0]) + self.W_db(db_tensor) + self.W_bs(bs_tensor) + output = torch.tanh(output) + + if isinstance(encodings, tuple): # return LSTM tuple + return (output.unsqueeze(0), encodings[1]) + else: + return output.unsqueeze(0)
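A shape sketch for DefaultPolicy, using the sizes from the shipped config in the mdrg policy module below (hid_size_enc=150, hid_size_pol=150, db_size=30, bs_size=94); the batch size of 2 is arbitrary:

import torch

pol = DefaultPolicy(hidden_size_pol=150, hidden_size=150, db_size=30, bs_size=94)
h = torch.zeros(1, 2, 150)   # encoder final hidden state, [layers, B, H]
db = torch.zeros(2, 30)      # database pointer, [B, db_size]
bs = torch.zeros(2, 94)      # belief-state summary, [B, bs_size]
out = pol(h, db, bs)         # tanh(W_u h + W_db db + W_bs bs), unsqueezed to [1, B, H_pol]
print(out.shape)             # torch.Size([1, 2, 150])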
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/policy.html b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/policy.html
new file mode 100644
index 0000000..ccdd928
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/policy.html
@@ -0,0 +1,862 @@
+convlab.modules.word_policy.multiwoz.mdrg.policy — ConvLab 0.1 documentation

Source code for convlab.modules.word_policy.multiwoz.mdrg.policy

+#!/usr/bin/env python
+# coding: utf-8
+
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import division, print_function, unicode_literals
+
+import json
+import os
+import pickle
+import re
+import shutil
+import tempfile
+import time
+import zipfile
+from copy import deepcopy
+
+import numpy as np
+import torch
+
+from convlab.lib.file_util import cached_path
+from convlab.modules.dst.multiwoz.dst_util import init_state
+from convlab.modules.policy.system.policy import SysPolicy
+from convlab.modules.word_policy.multiwoz.mdrg.model import Model
+from convlab.modules.word_policy.multiwoz.mdrg.utils import util, dbquery, delexicalize
+from convlab.modules.word_policy.multiwoz.mdrg.utils.nlp import normalize
+
+DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))), 'data/nrg/mdrg')
+
+
[docs]class Args(object): + pass
+args = Args() +args.no_cuda = True +args.seed = 1 +args.no_models = 20 +args.original = DATA_PATH #'/home/sule/projects/research/multiwoz/model/model/' +args.dropout = 0. +args.use_emb = 'False' +args.beam_width = 10 +args.write_n_best = False +args.model_path = os.path.join(DATA_PATH, 'translate.ckpt') +args.model_dir = DATA_PATH + '/' #'/home/sule/projects/research/multiwoz/model/model/' +args.model_name = 'translate.ckpt' +args.valid_output = 'val_dials' #'/home/sule/projects/research/multiwoz/model/data/val_dials/' +args.test_output = 'test_dials' #'/home/sule/projects/research/multiwoz/model/data/test_dials/' + +args.batch_size=64 +args.vocab_size=400 + +args.use_attn=False +args.attention_type='bahdanau' +args.use_emb=False + +args.emb_size=50 +args.hid_size_enc=150 +args.hid_size_dec=150 +args.hid_size_pol=150 +args.db_size=30 +args.bs_size=94 + +args.cell_type='lstm' +args.depth=1 +args.max_len=50 + +args.optim='adam' +args.lr_rate=0.005 +args.lr_decay=0.0 +args.l2_norm=0.00001 +args.clip=5.0 + +args.teacher_ratio=1.0 +args.dropout=0.0 + +args.no_cuda=True + +args.seed=0 +args.train_output='train_dials' #'data/train_dials/' + +args.max_epochs=15 +args.early_stop_count=2 + +args.load_param=False +args.epoch_load=0 + +args.mode='test' + +args.cuda = not args.no_cuda and torch.cuda.is_available() + +torch.manual_seed(args.seed) + +device = torch.device("cuda" if args.cuda else "cpu") + + +
[docs]def load_config(args):
+    config = util.unicode_to_utf8(
+        json.load(open('%s.json' % args.model_path, 'rb')))
+    for key, value in vars(args).items():  # the original `args.__args` attribute does not exist
+        try:
+            config[key] = value.value
+        except AttributeError:
+            config[key] = value
+
+    return config
+ + +
[docs]def addBookingPointer(state, pointer_vector): + """Add information about availability of the booking option.""" + # Booking pointer + rest_vec = np.array([1, 0]) + if "book" in state['restaurant']: + if "booked" in state['restaurant']['book']: + if state['restaurant']['book']["booked"]: + if "reference" in state['restaurant']['book']["booked"][0]: + rest_vec = np.array([0, 1]) + + hotel_vec = np.array([1, 0]) + if "book" in state['hotel']: + if "booked" in state['hotel']['book']: + if state['hotel']['book']["booked"]: + if "reference" in state['hotel']['book']["booked"][0]: + hotel_vec = np.array([0, 1]) + + train_vec = np.array([1, 0]) + if "book" in state['train']: + if "booked" in state['train']['book']: + if state['train']['book']["booked"]: + if "reference" in state['train']['book']["booked"][0]: + train_vec = np.array([0, 1]) + + pointer_vector = np.append(pointer_vector, rest_vec) + pointer_vector = np.append(pointer_vector, hotel_vec) + pointer_vector = np.append(pointer_vector, train_vec) + + # pprint(pointer_vector) + return pointer_vector
+ + +
[docs]def addDBPointer(state): + """Create database pointer for all related domains.""" + domains = ['restaurant', 'hotel', 'attraction', 'train'] + pointer_vector = np.zeros(6 * len(domains)) + db_results = {} + num_entities = {} + for domain in domains: + # entities = dbPointer.queryResultVenues(domain, {'metadata': state}) + entities = dbquery.query(domain, state[domain]['semi'].items()) + num_entities[domain] = len(entities) + if len(entities) > 0: + # fields = dbPointer.table_schema(domain) + # db_results[domain] = dict(zip(fields, entities[0])) + db_results[domain] = entities[0] + # pointer_vector = dbPointer.oneHotVector(len(entities), domain, pointer_vector) + pointer_vector = oneHotVector(len(entities), domain, pointer_vector) + + return pointer_vector, db_results, num_entities
+ +
[docs]def oneHotVector(num, domain, vector): + """Return number of available entities for particular domain.""" + domains = ['restaurant', 'hotel', 'attraction', 'train'] + number_of_options = 6 + if domain != 'train': + idx = domains.index(domain) + if num == 0: + vector[idx * 6: idx * 6 + 6] = np.array([1, 0, 0, 0, 0,0]) + elif num == 1: + vector[idx * 6: idx * 6 + 6] = np.array([0, 1, 0, 0, 0, 0]) + elif num == 2: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 1, 0, 0, 0]) + elif num == 3: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 1, 0, 0]) + elif num == 4: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 1, 0]) + elif num >= 5: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 0, 1]) + else: + idx = domains.index(domain) + if num == 0: + vector[idx * 6: idx * 6 + 6] = np.array([1, 0, 0, 0, 0, 0]) + elif num <= 2: + vector[idx * 6: idx * 6 + 6] = np.array([0, 1, 0, 0, 0, 0]) + elif num <= 5: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 1, 0, 0, 0]) + elif num <= 10: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 1, 0, 0]) + elif num <= 40: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 1, 0]) + elif num > 40: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 0, 1]) + + return vector
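A worked example of the bucketing: for every domain except 'train' the six slots encode the exact counts 0-4 plus '5 or more', while 'train' uses the coarser ranges 0, 1-2, 3-5, 6-10, 11-40 and over 40. With seven matching trains:

import numpy as np

vec = np.zeros(24)                   # 4 domains x 6 count buckets
vec = oneHotVector(7, 'train', vec)  # 7 falls in the 6-10 bucket
print(vec[18:24])                    # [0. 0. 0. 1. 0. 0.]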
+ + +
[docs]def delexicaliseReferenceNumber(sent, state):
+    """Based on the belief state, find the reference number that
+    was created randomly during data gathering."""
+    domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital']  # , 'police']
+    for domain in domains:
+        if state[domain]['book']['booked']:
+            for slot in state[domain]['book']['booked'][0]:
+                # both branches of the original if/else built the same placeholder
+                val = '[' + domain + '_' + slot + ']'
+                key = normalize(state[domain]['book']['booked'][0][slot])
+                sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
+
+                # try the reference written with a hashtag
+                key = normalize("#" + state[domain]['book']['booked'][0][slot])
+                sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
+
+                # try the reference written as ref#
+                key = normalize("ref#" + state[domain]['book']['booked'][0][slot])
+                sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
+    return sent
+ + +
[docs]def get_summary_bstate(bstate): + """Based on the mturk annotations we form multi-domain belief state""" + domains = ['taxi', 'restaurant', 'hospital', 'hotel', 'attraction', 'train', 'police'] + summary_bstate = [] + for domain in domains: + domain_active = False + + booking = [] + #print(domain,len(bstate[domain]['book'].keys())) + for slot in sorted(bstate[domain]['book'].keys()): + if slot == 'booked': + if bstate[domain]['book']['booked']: + booking.append(1) + else: + booking.append(0) + else: + if bstate[domain]['book'][slot] != "": + booking.append(1) + else: + booking.append(0) + if domain == 'train': + if 'people' not in bstate[domain]['book'].keys(): + booking.append(0) + else: + booking.append(1) + if 'ticket' not in bstate[domain]['book'].keys(): + booking.append(0) + else: + booking.append(1) + summary_bstate += booking + + for slot in bstate[domain]['semi']: + slot_enc = [0, 0, 0] # not mentioned, dontcare, filled + if bstate[domain]['semi'][slot] == 'not mentioned': + slot_enc[0] = 1 + elif bstate[domain]['semi'][slot] == 'dont care' or bstate[domain]['semi'][slot] == 'dontcare' or bstate[domain]['semi'][slot] == "don't care": + slot_enc[1] = 1 + elif bstate[domain]['semi'][slot]: + slot_enc[2] = 1 + if slot_enc != [0, 0, 0]: + domain_active = True + summary_bstate += slot_enc + + # quasi domain-tracker + if domain_active: + summary_bstate += [1] + else: + summary_bstate += [0] + + # pprint(summary_bstate) + #print(len(summary_bstate)) + assert len(summary_bstate) == 94 + return summary_bstate
+ + +
[docs]def populate_template(template, top_results, num_results, state): + active_domain = None if len(top_results.keys()) == 0 else list(top_results.keys())[0] + template = template.replace('book [value_count] of them', 'book one of them') + tokens = template.split() + response = [] + for token in tokens: + if token.startswith('[') and token.endswith(']'): + domain = token[1:-1].split('_')[0] + slot = token[1:-1].split('_')[1] + if domain == 'train' and slot == 'id': + slot = 'trainID' + if domain in top_results and len(top_results[domain]) > 0 and slot in top_results[domain]: + # print('{} -> {}'.format(token, top_results[domain][slot])) + response.append(top_results[domain][slot]) + elif domain == 'value': + if slot == 'count': + response.append(str(num_results)) + elif slot == 'place': + if 'arrive' in response: + for d in state: + if d == 'history': + continue + if 'destination' in state[d]['semi']: + response.append(state[d]['semi']['destination']) + break + elif 'leave' in response: + for d in state: + if d == 'history': + continue + if 'departure' in state[d]['semi']: + response.append(state[d]['semi']['departure']) + break + else: + try: + for d in state: + if d == 'history': + continue + for s in ['destination', 'departure']: + if s in state[d]['semi']: + response.append(state[d]['semi'][s]) + raise + except: + pass + else: + response.append(token) + elif slot == 'time': + if 'arrive' in ' '.join(response[-3:]): + if active_domain is not None and 'arriveBy' in top_results[active_domain]: + # print('{} -> {}'.format(token, top_results[active_domain]['arriveBy'])) + response.append(top_results[active_domain]['arriveBy']) + continue + for d in state: + if d == 'history': + continue + if 'arriveBy' in state[d]['semi']: + response.append(state[d]['semi']['arriveBy']) + break + elif 'leave' in ' '.join(response[-3:]): + if active_domain is not None and 'leaveAt' in top_results[active_domain]: + # print('{} -> {}'.format(token, top_results[active_domain]['leaveAt'])) + response.append(top_results[active_domain]['leaveAt']) + continue + for d in state: + if d == 'history': + continue + if 'leaveAt' in state[d]['semi']: + response.append(state[d]['semi']['leaveAt']) + break + elif 'book' in response: + if state['restaurant']['book']['time'] != "": + response.append(state['restaurant']['book']['time']) + else: + try: + for d in state: + if d == 'history': + continue + for s in ['arriveBy', 'leaveAt']: + if s in state[d]['semi']: + response.append(state[d]['semi'][s]) + raise + except: + pass + else: + response.append(token) + else: + # slot-filling based on query results + for d in top_results: + if slot in top_results[d]: + response.append(top_results[d][slot]) + break + else: + # slot-filling based on belief state + for d in state: + if d == 'history': + continue + if slot in state[d]['semi']: + response.append(state[d]['semi'][slot]) + break + else: + response.append(token) + else: + if domain == 'hospital': + if slot == 'phone': + response.append('01223216297') + elif slot == 'department': + response.append('neurosciences critical care unit') + elif domain == 'police': + if slot == 'phone': + response.append('01223358966') + elif slot == 'name': + response.append('Parkside Police Station') + elif slot == 'address': + response.append('Parkside, Cambridge') + elif domain == 'taxi': + if slot == 'phone': + response.append('01223358966') + elif slot == 'color': + response.append('white') + elif slot == 'type': + response.append('toyota') + else: + # print(token) + response.append(token) + 
else: + response.append(token) + + try: + response = ' '.join(response) + except Exception as e: + # pprint(response) + raise + response = response.replace(' -s', 's') + response = response.replace(' -ly', 'ly') + response = response.replace(' .', '.') + response = response.replace(' ?', '?') + return response
+ + +
[docs]def mark_not_mentioned(state): + for domain in state: + # if domain == 'history': + if domain not in ['police', 'hospital', 'taxi', 'train', 'attraction', 'restaurant', 'hotel']: + continue + try: + # if len([s for s in state[domain]['semi'] if s != 'book' and state[domain]['semi'][s] != '']) > 0: + # for s in state[domain]['semi']: + # if s != 'book' and state[domain]['semi'][s] == '': + # state[domain]['semi'][s] = 'not mentioned' + for s in state[domain]['semi']: + if state[domain]['semi'][s] == '': + state[domain]['semi'][s] = 'not mentioned' + except Exception as e: + # print(str(e)) + # pprint(state[domain]) + pass
+ + +
[docs]def predict(model, prev_state, prev_active_domain, state, dic): + start_time = time.time() + model.beam_search = False + input_tensor = []; bs_tensor = []; db_tensor = [] + + usr = state['history'][-1][-1] + + prev_state = deepcopy(prev_state['belief_state']) + state = deepcopy(state['belief_state']) + + mark_not_mentioned(prev_state) + mark_not_mentioned(state) + + words = usr.split() + usr = delexicalize.delexicalise(' '.join(words), dic) + + # parsing reference number GIVEN belief state + usr = delexicaliseReferenceNumber(usr, state) + + # changes to numbers only here + digitpat = re.compile('\d+') + usr = re.sub(digitpat, '[value_count]', usr) + # dialogue = fixDelex(dialogue_name, dialogue, data2, idx, idx_acts) + + # add database pointer + pointer_vector, top_results, num_results = addDBPointer(state) + # add booking pointer + pointer_vector = addBookingPointer(state, pointer_vector) + belief_summary = get_summary_bstate(state) + + tensor = [model.input_word2index(word) for word in normalize(usr).strip(' ').split(' ')] + [util.EOS_token] + input_tensor.append(torch.LongTensor(tensor)) + bs_tensor.append(belief_summary) # + db_tensor.append(pointer_vector) # db results and booking + # bs_tensor.append([0.] * 94) # + # db_tensor.append([0.] * 30) # db results and booking + # create an empty matrix with padding tokens + input_tensor, input_lengths = util.padSequence(input_tensor) + bs_tensor = torch.tensor(bs_tensor, dtype=torch.float, device=device) + db_tensor = torch.tensor(db_tensor, dtype=torch.float, device=device) + + output_words, loss_sentence = model.predict(input_tensor, input_lengths, input_tensor, input_lengths, + db_tensor, bs_tensor) + active_domain = get_active_domain(prev_active_domain, prev_state, state) + if active_domain is not None and active_domain in num_results: + num_results = num_results[active_domain] + else: + num_results = 0 + if active_domain is not None and active_domain in top_results: + top_results = {active_domain: top_results[active_domain]} + else: + top_results = {} + response = populate_template(output_words[0], top_results, num_results, state) + return response, active_domain
+ + +
[docs]def get_active_domain(prev_active_domain, prev_state, state): + domains = ['hotel', 'restaurant', 'attraction', 'train', 'taxi', 'hospital', 'police'] + active_domain = None + # print('get_active_domain') + for domain in domains: + if domain not in prev_state and domain not in state: + continue + if domain in prev_state and domain not in state: + return domain + elif domain not in prev_state and domain in state: + return domain + elif prev_state[domain] != state[domain]: + active_domain = domain + if active_domain is None: + active_domain = prev_active_domain + return active_domain
+ + +
[docs]def loadModel(num): + # Load dictionaries + with open(os.path.join(DATA_PATH, 'input_lang.index2word.json')) as f: + input_lang_index2word = json.load(f) + with open(os.path.join(DATA_PATH, 'input_lang.word2index.json')) as f: + input_lang_word2index = json.load(f) + with open(os.path.join(DATA_PATH, 'output_lang.index2word.json')) as f: + output_lang_index2word = json.load(f) + with open(os.path.join(DATA_PATH, 'output_lang.word2index.json')) as f: + output_lang_word2index = json.load(f) + + # Reload existing checkpoint + model = Model(args, input_lang_index2word, output_lang_index2word, input_lang_word2index, output_lang_word2index) + model.loadModel(iter=num) + + return model
+ +DEFAULT_CUDA_DEVICE = -1 +DEFAULT_DIRECTORY = "models" +DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "milu.tar.gz") + +
[docs]class MDRGWordPolicy(SysPolicy): + def __init__(self, + archive_file=DEFAULT_ARCHIVE_FILE, + cuda_device=DEFAULT_CUDA_DEVICE, + model_file=None): + + if not os.path.isfile(archive_file): + if not model_file: + raise Exception("No model for MDRG is specified!") + archive_file = cached_path(model_file) + + temp_path = tempfile.mkdtemp() + zip_ref = zipfile.ZipFile(archive_file, 'r') + zip_ref.extractall(temp_path) + zip_ref.close() + + self.dic = pickle.load(open(os.path.join(temp_path, 'mdrg/svdic.pkl'), 'rb')) + # Load dictionaries + with open(os.path.join(temp_path, 'mdrg/input_lang.index2word.json')) as f: + input_lang_index2word = json.load(f) + with open(os.path.join(temp_path, 'mdrg/input_lang.word2index.json')) as f: + input_lang_word2index = json.load(f) + with open(os.path.join(temp_path, 'mdrg/output_lang.index2word.json')) as f: + output_lang_index2word = json.load(f) + with open(os.path.join(temp_path, 'mdrg/output_lang.word2index.json')) as f: + output_lang_word2index = json.load(f) + self.response_model = Model(args, input_lang_index2word, output_lang_index2word, input_lang_word2index, output_lang_word2index) + self.response_model.loadModel(os.path.join(temp_path, 'mdrg/mdrg')) + + shutil.rmtree(temp_path) + + self.prev_state = init_state() + self.prev_active_domain = None + +
[docs] def predict(self, state): + try: + response, active_domain = predict(self.response_model, self.prev_state, self.prev_active_domain, state, self.dic) + except Exception as e: + print('Response generation error', e) + response = 'What did you say?' + active_domain = None + self.prev_state = deepcopy(state) + self.prev_active_domain = active_domain + return response
+ + + +if __name__ == '__main__': + dic = pickle.load(open(os.path.join(DATA_PATH, 'svdic.pkl'), 'rb')) + state = { + "police": { + "book": { + "booked": [] + }, + "semi": {} + }, + "hotel": { + "book": { + "booked": [], + "people": "", + "day": "", + "stay": "" + }, + "semi": { + "name": "", + "area": "", + "parking": "", + "pricerange": "", + "stars": "", + "internet": "", + "type": "" + } + }, + "attraction": { + "book": { + "booked": [] + }, + "semi": { + "type": "", + "name": "", + "area": "" + } + }, + "restaurant": { + "book": { + "booked": [], + "people": "", + "day": "", + "time": "" + }, + "semi": { + "food": "", + "price range": "", + "name": "", + "area": "", + } + }, + "hospital": { + "book": { + "booked": [] + }, + "semi": { + "department": "" + } + }, + "taxi": { + "book": { + "booked": [] + }, + "semi": { + "leaveAt": "", + "destination": "", + "departure": "", + "arriveBy": "" + } + }, + "train": { + "book": { + "booked": [], + "people": "" + }, + "semi": { + "leaveAt": "", + "destination": "", + "day": "", + "arriveBy": "", + "departure": "" + } + } + } + + m = loadModel(15) + + # modify state + s = deepcopy(state) + s['history'] = [['null', 'I want a korean restaurant in the centre.']] + s['attraction']['semi']['area'] = 'centre' + s['restaurant']['semi']['area'] = 'centre' + s['restaurant']['semi']['food'] = 'korean' + # s['history'] = [['null', 'i need to book a hotel in the east that has 4 stars.']] + # s['hotel']['semi']['area'] = 'east' + # s['hotel']['semi']['stars'] = '4' + predict(m, state, s, dic) + + # import requests + # resp = requests.post('http://localhost:10001', json={'history': [['null', 'I want a korean restaurant in the centre.']]}) + # if resp.status_code != 200: + # # raise Exception('POST /tasks/ {}'.format(resp.status_code)) + # response = "Sorry, there is some problem" + # else: + # response = resp.json()["response"] + # print('Response: {}'.format(response)) +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/utils/dbquery.html b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/utils/dbquery.html
new file mode 100644
index 0000000..344a0e5
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/utils/dbquery.html
@@ -0,0 +1,244 @@
+convlab.modules.word_policy.multiwoz.mdrg.utils.dbquery — ConvLab 0.1 documentation

Source code for convlab.modules.word_policy.multiwoz.mdrg.utils.dbquery

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+"""
+import json
+import os
+import random
+
+# loading databases
+domains = ['restaurant', 'hotel', 'attraction', 'train', 'hospital', 'taxi', 'police']
+dbs = {}
+for domain in domains:
+    dbs[domain] = json.load(open(os.path.join(
+        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 
+        'db/{}_db.json'.format(domain))))
+
+
[docs]def query(domain, constraints, ignore_open=True): + """Returns the list of entities for a given domain + based on the annotation of the belief state""" + # query the db + if domain == 'taxi': + return [{'taxi_colors': random.choice(dbs[domain]['taxi_colors']), + 'taxi_types': random.choice(dbs[domain]['taxi_types']), + 'taxi_phone': [random.randint(1, 9) for _ in range(10)]}] + if domain == 'police': + return dbs['police'] + if domain == 'hospital': + return dbs['hospital'] + + found = [] + for record in dbs[domain]: + for key, val in constraints: + if val == "" or val == "dont care" or val == 'not mentioned' or val == "don't care" or val == "dontcare" or val == "do n't care": + pass + else: + if key not in record: + continue + if key == 'leaveAt': + val1 = int(val.split(':')[0]) * 100 + int(val.split(':')[1]) + val2 = int(record['leaveAt'].split(':')[0]) * 100 + int(record['leaveAt'].split(':')[1]) + if val1 > val2: + break + elif key == 'arriveBy': + val1 = int(val.split(':')[0]) * 100 + int(val.split(':')[1]) + val2 = int(record['arriveBy'].split(':')[0]) * 100 + int(record['arriveBy'].split(':')[1]) + if val1 < val2: + break + # elif ignore_open and key in ['destination', 'departure', 'name']: + elif ignore_open and key in ['destination', 'departure']: + continue + else: + if val.strip() != record[key].strip(): + break + else: + found.append(record) + + return found
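A hypothetical usage sketch: the constraint pairs come from a domain's 'semi' belief-state dict, and empty or 'dont care' values are skipped, as the loop above shows:

results = query('restaurant', [('food', 'korean'), ('area', 'centre'), ('pricerange', 'dont care')])
print(len(results))            # number of matching restaurant records
if results:
    print(results[0]['name'])  # the first venue, as handed to the policy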
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/utils/delexicalize.html b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/utils/delexicalize.html
new file mode 100644
index 0000000..b67a829
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/utils/delexicalize.html
@@ -0,0 +1,339 @@
+convlab.modules.word_policy.multiwoz.mdrg.utils.delexicalize — ConvLab 0.1 documentation

Source code for convlab.modules.word_policy.multiwoz.mdrg.utils.delexicalize

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import pickle
+import re
+
+import simplejson as json
+
+from convlab.modules.word_policy.multiwoz.mdrg.utils.nlp import normalize
+
+digitpat = re.compile('\d+')
+timepat = re.compile("\d{1,2}[:]\d{1,2}")
+pricepat2 = re.compile("\d{1,3}[.]\d{1,2}")
+
+# FORMAT
+# domain_value
+# restaurant_postcode
+# restaurant_address
+# taxi_car8
+# taxi_number
+# train_id etc..
+
+
+
[docs]def prepareSlotValuesIndependent(): + domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police'] + requestables = ['phone', 'address', 'postcode', 'reference', 'id'] + dic = [] + dic_area = [] + dic_food = [] + dic_price = [] + + # read databases + for domain in domains: + try: + fin = open('db/' + domain + '_db.json') + db_json = json.load(fin) + fin.close() + + for ent in db_json: + for key, val in ent.items(): + if val == '?' or val == 'free': + pass + elif key == 'address': + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + if "road" in val: + val = val.replace("road", "rd") + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + elif "rd" in val: + val = val.replace("rd", "road") + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + elif "st" in val: + val = val.replace("st", "street") + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + elif "street" in val: + val = val.replace("street", "st") + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + elif key == 'name': + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + if "b & b" in val: + val = val.replace("b & b", "bed and breakfast") + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + elif "bed and breakfast" in val: + val = val.replace("bed and breakfast", "b & b") + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + elif "hotel" in val and 'gonville' not in val: + val = val.replace("hotel", "") + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + elif "restaurant" in val: + val = val.replace("restaurant", "") + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + elif key == 'postcode': + dic.append((normalize(val), '[' + domain + '_' + 'postcode' + ']')) + elif key == 'phone': + dic.append((val, '[' + domain + '_' + 'phone' + ']')) + elif key == 'trainID': + dic.append((normalize(val), '[' + domain + '_' + 'id' + ']')) + elif key == 'department': + dic.append((normalize(val), '[' + domain + '_' + 'department' + ']')) + + # NORMAL DELEX + elif key == 'area': + dic_area.append((normalize(val), '[' + 'value' + '_' + 'area' + ']')) + elif key == 'food': + dic_food.append((normalize(val), '[' + 'value' + '_' + 'food' + ']')) + elif key == 'pricerange': + dic_price.append((normalize(val), '[' + 'value' + '_' + 'pricerange' + ']')) + else: + pass + # TODO car type? 
+ except: + pass + + if domain == 'hospital': + dic.append((normalize('Hills Rd'), '[' + domain + '_' + 'address' + ']')) + dic.append((normalize('Hills Road'), '[' + domain + '_' + 'address' + ']')) + dic.append((normalize('CB20QQ'), '[' + domain + '_' + 'postcode' + ']')) + dic.append(('01223245151', '[' + domain + '_' + 'phone' + ']')) + dic.append(('1223245151', '[' + domain + '_' + 'phone' + ']')) + dic.append(('0122324515', '[' + domain + '_' + 'phone' + ']')) + dic.append((normalize('Addenbrookes Hospital'), '[' + domain + '_' + 'name' + ']')) + + elif domain == 'police': + dic.append((normalize('Parkside'), '[' + domain + '_' + 'address' + ']')) + dic.append((normalize('CB11JG'), '[' + domain + '_' + 'postcode' + ']')) + dic.append(('01223358966', '[' + domain + '_' + 'phone' + ']')) + dic.append(('1223358966', '[' + domain + '_' + 'phone' + ']')) + dic.append((normalize('Parkside Police Station'), '[' + domain + '_' + 'name' + ']')) + + # add at the end places from trains + # fin = file('db/' + 'train' + '_db.json') + fin = open('db/' + 'train' + '_db.json') + db_json = json.load(fin) + fin.close() + + for ent in db_json: + for key, val in ent.items(): + if key == 'departure' or key == 'destination': + dic.append((normalize(val), '[' + 'value' + '_' + 'place' + ']')) + + # add specific values: + for key in ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']: + dic.append((normalize(key), '[' + 'value' + '_' + 'day' + ']')) + + # more general values add at the end + dic.extend(dic_area) + dic.extend(dic_food) + dic.extend(dic_price) + + return dic
+ + +
[docs]def delexicalise(utt, dictionary):
+    for key, val in dictionary:
+        utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
+        utt = utt[1:-1]  # strip the sentinel spaces added for whole-word matching
+
+    return utt
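A minimal sketch of a call; the dictionary produced by prepareSlotValuesIndependent is a list of (surface form, placeholder) pairs, and these two entries are made up for illustration:

dic = [('golden wok', '[restaurant_name]'), ('centre', '[value_area]')]
print(delexicalise('i want the golden wok in the centre', dic))
# -> 'i want the [restaurant_name] in the [value_area]'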
+ + +
[docs]def delexicaliseDomain(utt, dictionary, domain):
+    for key, val in dictionary:
+        if key == domain or key == 'value':
+            utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
+            utt = utt[1:-1]  # strip the sentinel spaces added for whole-word matching
+
+    # go through the rest of the dictionary in case anything was missed
+    for key, val in dictionary:
+        utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
+        utt = utt[1:-1]
+    return utt
+ +if __name__ == '__main__': + dic = prepareSlotValuesIndependent() + pickle.dump(dic, open('data/svdic.pkl', 'wb')) +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/utils/nlp.html b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/utils/nlp.html
new file mode 100644
index 0000000..f387551
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/utils/nlp.html
@@ -0,0 +1,440 @@
+convlab.modules.word_policy.multiwoz.mdrg.utils.nlp — ConvLab 0.1 documentation

Source code for convlab.modules.word_policy.multiwoz.mdrg.utils.nlp

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import math
+import os
+import re
+from collections import Counter
+
+from nltk.util import ngrams
+
+timepat = re.compile("\d{1,2}[:]\d{1,2}")
+pricepat = re.compile("\d{1,3}[.]\d{1,2}")
+
+
+# fin = file('utils/mapping.pair')
+# fin = open('/home/sule/projects/research/multiwoz/utils/mapping.pair')
+fin = open(os.path.join(os.path.dirname(os.path.abspath(__file__)),'mapping.pair'))
+replacements = []
+for line in fin.readlines():
+    tok_from, tok_to = line.replace('\n', '').split('\t')
+    replacements.append((' ' + tok_from + ' ', ' ' + tok_to + ' '))
+
+
+
[docs]def insertSpace(token, text): + sidx = 0 + while True: + sidx = text.find(token, sidx) + if sidx == -1: + break + if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \ + re.match('[0-9]', text[sidx + 1]): + sidx += 1 + continue + if text[sidx - 1] != ' ': + text = text[:sidx] + ' ' + text[sidx:] + sidx += 1 + if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ': + text = text[:sidx + 1] + ' ' + text[sidx + 1:] + sidx += 1 + return text
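Behaviour sketch: the token is padded with spaces unless it sits between two digits, which keeps prices like 12.50 intact when token='.' (the inputs below are illustrative):

print(insertSpace('.', 'fare.'))  # 'fare .'
print(insertSpace('.', '12.50'))  # '12.50' stays unchanged (digit context)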
+ + +
[docs]def normalize(text): + # lower case every word + text = text.lower() + + # replace white spaces in front and end + text = re.sub(r'^\s*|\s*$', '', text) + + # hotel domain pfb30 + text = re.sub(r"b&b", "bed and breakfast", text) + text = re.sub(r"b and b", "bed and breakfast", text) + + # normalize phone number + ms = re.findall('\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m[0], sidx) + if text[sidx - 1] == '(': + sidx -= 1 + eidx = text.find(m[-1], sidx) + len(m[-1]) + text = text.replace(text[sidx:eidx], ''.join(m)) + + # normalize postcode + ms = re.findall('([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})', + text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m, sidx) + eidx = sidx + len(m) + text = text[:sidx] + re.sub('[,\. ]', '', m) + text[eidx:] + + # weird unicode bug + text = re.sub(u"(\u2018|\u2019)", "'", text) + + # replace time and and price + text = re.sub(timepat, ' [value_time] ', text) + text = re.sub(pricepat, ' [value_price] ', text) + #text = re.sub(pricepat2, '[value_price]', text) + + # replace st. + text = text.replace(';', ',') + text = re.sub('$\/', '', text) + text = text.replace('/', ' and ') + + # replace other special characters + text = text.replace('-', ' ') + text = re.sub('[\":\<>@\(\)]', '', text) + + # insert white space before and after tokens: + for token in ['?', '.', ',', '!']: + text = insertSpace(token, text) + + # insert white space for 's + text = insertSpace('\'s', text) + + # replace it's, does't, you'd ... etc + text = re.sub('^\'', '', text) + text = re.sub('\'$', '', text) + text = re.sub('\'\s', ' ', text) + text = re.sub('\s\'', ' ', text) + for fromx, tox in replacements: + text = ' ' + text + ' ' + text = text.replace(fromx, tox)[1:-1] + + # remove multiple spaces + text = re.sub(' +', ' ', text) + + # concatenate numbers + tmp = text + tokens = text.split() + i = 1 + while i < len(tokens): + if re.match(u'^\d+$', tokens[i]) and \ + re.match(u'\d+$', tokens[i - 1]): + tokens[i - 1] += tokens[i] + del tokens[i] + else: + i += 1 + text = ' '.join(tokens) + + return text
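An illustrative call; the output shown is approximate, since the mapping.pair replacements loaded above may rewrite further tokens:

print(normalize('The train leaves at 09:15, it costs 23.60 pounds'))
# -> roughly: 'the train leaves at [value_time] , it costs [value_price] pounds'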
+ + +
[docs]class BLEUScorer(object): + ## BLEU score calculator via GentScorer interface + ## it calculates the BLEU-4 by taking the entire corpus in + ## Calulate based multiple candidates against multiple references + def __init__(self): + pass + +
[docs] def score(self, hypothesis, corpus, n=1): + # containers + count = [0, 0, 0, 0] + clip_count = [0, 0, 0, 0] + r = 0 + c = 0 + weights = [0.25, 0.25, 0.25, 0.25] + + # accumulate ngram statistics + for hyps, refs in zip(hypothesis, corpus): + if isinstance(hyps[0], list): + hyps = [hyp.split() for hyp in hyps[0]] + else: + hyps = [hyp.split() for hyp in hyps] + + refs = [ref.split() for ref in refs] + + # Shawn's evaluation + refs[0] = [u'GO_'] + refs[0] + [u'EOS_'] + hyps[0] = [u'GO_'] + hyps[0] + [u'EOS_'] + + for idx, hyp in enumerate(hyps): + for i in range(4): + # accumulate ngram counts + hypcnts = Counter(ngrams(hyp, i + 1)) + cnt = sum(hypcnts.values()) + count[i] += cnt + + # compute clipped counts + max_counts = {} + for ref in refs: + refcnts = Counter(ngrams(ref, i + 1)) + for ng in hypcnts: + max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) + clipcnt = dict((ng, min(count, max_counts[ng])) \ + for ng, count in hypcnts.items()) + clip_count[i] += sum(clipcnt.values()) + + # accumulate r & c + bestmatch = [1000, 1000] + for ref in refs: + if bestmatch[0] == 0: break + diff = abs(len(ref) - len(hyp)) + if diff < bestmatch[0]: + bestmatch[0] = diff + bestmatch[1] = len(ref) + r += bestmatch[1] + c += len(hyp) + if n == 1: + break + # computing bleu score + p0 = 1e-7 + bp = 1 if c > r else math.exp(1 - float(r) / float(c)) + p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ + for i in range(4)] + s = math.fsum(w * math.log(p_n) \ + for w, p_n in zip(weights, p_ns) if p_n) + bleu = bp * math.exp(s) + return bleu
+ + +
[docs]class GentScorer(object): + def __init__(self, detectfile): + self.bleuscorer = BLEUScorer() + +
[docs]    def scoreBLEU(self, parallel_corpus):
+        # BLEUScorer.score expects hypotheses and references as separate
+        # arguments, so unpack the (hypothesis, corpus) pair
+        return self.bleuscorer.score(*parallel_corpus)
+ + +
[docs]def sentence_bleu_4(hyp, refs, weights=[0.25, 0.25, 0.25, 0.25]): + # input : single sentence, multiple references + count = [0, 0, 0, 0] + clip_count = [0, 0, 0, 0] + r = 0 + c = 0 + + for i in range(4): + hypcnts = Counter(ngrams(hyp, i + 1)) + cnt = sum(hypcnts.values()) + count[i] += cnt + + # compute clipped counts + max_counts = {} + for ref in refs: + refcnts = Counter(ngrams(ref, i + 1)) + for ng in hypcnts: + max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) + clipcnt = dict((ng, min(count, max_counts[ng])) \ + for ng, count in hypcnts.items()) + clip_count[i] += sum(clipcnt.values()) + + bestmatch = [1000, 1000] + for ref in refs: + if bestmatch[0] == 0: + break + diff = abs(len(ref) - len(hyp)) + if diff < bestmatch[0]: + bestmatch[0] = diff + bestmatch[1] = len(ref) + r = bestmatch[1] + c = len(hyp) + + p0 = 1e-7 + bp = math.exp(-abs(1.0 - float(r) / float(c + p0))) + + p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 for i in range(4)] + s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n) + bleu_hyp = bp * math.exp(s) + + return bleu_hyp
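A quick smoke test one might run: an identical hypothesis and reference should score essentially 1.0, with the p0 smoothing keeping it marginally below:

hyp = 'there is a cheap restaurant in the centre'.split()
refs = ['there is a cheap restaurant in the centre'.split()]
print(round(sentence_bleu_4(hyp, refs), 3))  # ~1.0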
+ +if __name__ == '__main__': + text = "restaurant's CB39AL one seven" + text = "I'm I'd restaurant's CB39AL 099939399 one seven" + text = "ndd 19.30 nndd" + #print re.match("(\d+).(\d+)", text) + m = re.findall("(\d+\.\d+)", text) + print(m) + #print m[0].strip('.') + print(re.sub('\.', '', m[0])) + #print m.groups() + #print text +
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/utils/util.html b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/utils/util.html
new file mode 100644
index 0000000..1ee0c2b
--- /dev/null
+++ b/docs/build/html/_modules/convlab/modules/word_policy/multiwoz/mdrg/utils/util.html
@@ -0,0 +1,288 @@
+convlab.modules.word_policy.multiwoz.mdrg.utils.util — ConvLab 0.1 documentation

Source code for convlab.modules.word_policy.multiwoz.mdrg.utils.util

+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+'''
+Utility functions
+'''
+
+import argparse
+import json
+import math
+import pickle as pkl
+import sys
+import time
+
+import numpy as np
+import torch
+
+# DEFINE special tokens
+SOS_token = 0
+EOS_token = 1
+UNK_token = 2
+PAD_token = 3
+
+
+
[docs]def padSequence(tensor): + pad_token = PAD_token + tensor_lengths = [len(sentence) for sentence in tensor] + longest_sent = max(tensor_lengths) + batch_size = len(tensor) + padded_tensor = np.ones((batch_size, longest_sent)) * pad_token + + # copy over the actual sequences + for i, x_len in enumerate(tensor_lengths): + sequence = tensor[i] + padded_tensor[i, 0:x_len] = sequence[:x_len] + + padded_tensor = torch.LongTensor(padded_tensor) + return padded_tensor, tensor_lengths
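Usage sketch: a ragged batch is padded to the longest sequence with PAD_token (3), and the true lengths are returned alongside:

import torch

batch = [torch.LongTensor([5, 6, 1]), torch.LongTensor([7, 1])]
padded, lengths = padSequence(batch)
print(padded.tolist(), lengths)  # [[5, 6, 1], [7, 1, 3]] [3, 2]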
+ + +
[docs]def loadDialogue(model, val_file, input_tensor, target_tensor, bs_tensor, db_tensor): + # Iterate over dialogue + for idx, (usr, sys, bs, db) in enumerate( + zip(val_file['usr'], val_file['sys'], val_file['bs'], val_file['db'])): + tensor = [model.input_word2index(word) for word in usr.strip(' ').split(' ')] + [ + EOS_token] # model.input_word2index(word) + input_tensor.append(torch.LongTensor(tensor)) # .view(-1, 1)) + + tensor = [model.output_word2index(word) for word in sys.strip(' ').split(' ')] + [EOS_token] + target_tensor.append(torch.LongTensor(tensor)) # .view(-1, 1) + + bs_tensor.append([float(belief) for belief in bs]) + db_tensor.append([float(pointer) for pointer in db]) + + return input_tensor, target_tensor, bs_tensor, db_tensor
+ + +#json loads strings as unicode; we currently still work with Python 2 strings, and need conversion +
[docs]def unicode_to_utf8(d): + return dict((key.encode("UTF-8"), value) for (key,value) in d.items())
+ + +
[docs]def load_dict(filename): + try: + with open(filename, 'rb') as f: + return unicode_to_utf8(json.load(f)) + except: + with open(filename, 'rb') as f: + return pkl.load(f)
+ + +
[docs]def load_config(basename): + try: + with open('%s.json' % basename, 'rb') as f: + return json.load(f) + except: + try: + with open('%s.pkl' % basename, 'rb') as f: + return pkl.load(f) + except: + sys.stderr.write('Error: config file {0}.json is missing\n'.format(basename)) + sys.exit(1)
+ + +
[docs]def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.')
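Typical argparse wiring for str2bool (the flag name here is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_attn', type=str2bool, default=False)
print(parser.parse_args(['--use_attn', 'yes']).use_attn)  # True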
+ + +
[docs]def asMinutes(s): + m = math.floor(s / 60) + s -= m * 60 + return '%dm %ds' % (m, s)
+ + +
[docs]def timeSince(since, percent): + now = time.time() + s = now - since + return '%s ' % (asMinutes(s))
\ No newline at end of file
diff --git a/docs/build/html/_modules/convlab/spec/spec_util.html b/docs/build/html/_modules/convlab/spec/spec_util.html
new file mode 100644
index 0000000..3065ae7
--- /dev/null
+++ b/docs/build/html/_modules/convlab/spec/spec_util.html
@@ -0,0 +1,447 @@
+convlab.spec.spec_util — ConvLab 0.1 documentation

Source code for convlab.spec.spec_util

+# The spec module
+# Manages specification to run things in lab
+import itertools
+import json
+import os
+from string import Template
+
+import pydash as ps
+
+from convlab.lib import logger, util
+
+SPEC_DIR = 'convlab/spec'
+'''
+All spec values are already param, inferred automatically.
+To change from a value into param range, e.g.
+- single: "explore_anneal_epi": 50
+- continuous param: "explore_anneal_epi": {"min": 50, "max": 100, "dist": "uniform"}
+- discrete range: "explore_anneal_epi": {"values": [50, 75, 100]}
+'''
+SPEC_FORMAT = {
+    "agent": [{
+        "name": str,
+        "algorithm": dict,
+        # "memory": dict,
+        # "net": dict,
+    }],
+    "env": [{
+        "name": str,
+        "max_t": (type(None), int, float),
+        # "max_frame": (int, float),
+    }],
+    # "body": {
+    #     "product": ["outer", "inner", "custom"],
+    #     "num": (int, list),
+    # },
+    "meta": {
+        "eval_frequency": (int, float),
+        "max_session": int,
+        "max_trial": (type(None), int),
+    },
+    "name": str,
+}
+logger = logger.get_logger(__name__)
+
+
+
[docs]def check_comp_spec(comp_spec, comp_spec_format): + '''Base method to check component spec''' + for spec_k, spec_format_v in comp_spec_format.items(): + comp_spec_v = comp_spec[spec_k] + if ps.is_list(spec_format_v): + v_set = spec_format_v + assert comp_spec_v in v_set, f'Component spec value {ps.pick(comp_spec, spec_k)} needs to be one of {util.to_json(v_set)}' + else: + v_type = spec_format_v + assert isinstance(comp_spec_v, v_type), f'Component spec {ps.pick(comp_spec, spec_k)} needs to be of type: {v_type}' + if isinstance(v_type, tuple) and int in v_type and isinstance(comp_spec_v, float): + # cast if it can be int + comp_spec[spec_k] = int(comp_spec_v)
+ + +
[docs]def check_body_spec(spec):
+    '''Base method to check body spec for multi-agent multi-env'''
+    ae_product = ps.get(spec, 'body.product')
+    body_num = ps.get(spec, 'body.num')
+    if ae_product == 'outer':
+        pass
+    elif ae_product == 'inner':
+        agent_num = len(spec['agent'])
+        env_num = len(spec['env'])
+        # f-string prefix was missing, so the counts were never interpolated
+        assert agent_num == env_num, f'Agent and Env spec length must be equal for body `inner` product. Given {agent_num}, {env_num}'
+    else:  # custom
+        assert ps.is_list(body_num)
+ + +
[docs]def check_compatibility(spec): + '''Check compatibility among spec setups''' + # TODO expand to be more comprehensive + if spec['meta'].get('distributed') == 'synced': + assert ps.get(spec, 'agent.0.net.gpu') == False, f'Distributed mode "synced" works with CPU only. Set gpu: false.'
+ + +
[docs]def check(spec): + '''Check a single spec for validity''' + try: + spec_name = spec.get('name') + assert set(spec.keys()) >= set(SPEC_FORMAT.keys()), f'Spec needs to follow spec.SPEC_FORMAT. Given \n {spec_name}: {util.to_json(spec)}' + for agent_spec in spec['agent']: + check_comp_spec(agent_spec, SPEC_FORMAT['agent'][0]) + for env_spec in spec['env']: + check_comp_spec(env_spec, SPEC_FORMAT['env'][0]) + # check_comp_spec(spec['body'], SPEC_FORMAT['body']) + check_comp_spec(spec['meta'], SPEC_FORMAT['meta']) + # check_body_spec(spec) + check_compatibility(spec) + except Exception as e: + logger.exception(f'spec {spec_name} fails spec check') + raise e + return True
+ + +
[docs]def check_all(): + '''Check all spec files, all specs.''' + spec_files = ps.filter_(os.listdir(SPEC_DIR), lambda f: f.endswith('.json') and not f.startswith('_')) + for spec_file in spec_files: + spec_dict = util.read(f'{SPEC_DIR}/{spec_file}') + for spec_name, spec in spec_dict.items(): + # fill-in info at runtime + spec['name'] = spec_name + spec = extend_meta_spec(spec) + try: + check(spec) + except Exception as e: + logger.exception(f'spec_file {spec_file} fails spec check') + raise e + logger.info(f'Checked all specs from: {ps.join(spec_files, ",")}') + return True
+ + +
[docs]def extend_meta_spec(spec): + '''Extend meta spec with information for lab functions''' + extended_meta_spec = { + # reset lab indices to -1 so that they tick to 0 + 'experiment': -1, + 'trial': -1, + 'session': -1, + 'cuda_offset': int(os.environ.get('CUDA_OFFSET', 0)), + 'experiment_ts': util.get_ts(), + 'prepath': None, + # ckpt extends prepath, e.g. ckpt_str = ckpt-epi10-totalt1000 + 'ckpt': None, + 'git_sha': util.get_git_sha(), + 'random_seed': None, + 'eval_model_prepath': None, + } + spec['meta'].update(extended_meta_spec) + return spec
+ + +
[docs]def get(spec_file, spec_name): + ''' + Get an experiment spec from spec_file, spec_name. + Auto-check spec. + @example + + spec = spec_util.get('base.json', 'base_case_openai') + ''' + spec_file = spec_file.replace(SPEC_DIR, '') # cleanup + if 'data/' in spec_file: + assert spec_name in spec_file, 'spec_file in data/ must be lab-generated and contains spec_name' + spec = util.read(spec_file) + else: + spec_file = f'{SPEC_DIR}/{spec_file}' # allow direct filename + spec_dict = util.read(spec_file) + assert spec_name in spec_dict, f'spec_name {spec_name} is not in spec_file {spec_file}. Choose from:\n {ps.join(spec_dict.keys(), ",")}' + spec = spec_dict[spec_name] + # fill-in info at runtime + spec['name'] = spec_name + spec = extend_meta_spec(spec) + check(spec) + return spec
+ + +
[docs]def get_eval_spec(spec_file, spec_name, prename=None): + '''Get spec for eval mode''' + spec = get(spec_file, spec_name) + spec['meta']['ckpt'] = 'eval' + spec['meta']['eval_model_prepath'] = prename + return spec
+ + +
[docs]def get_param_specs(spec): + '''Return a list of specs with substituted spec_params''' + assert 'spec_params' in spec, 'Parametrized spec needs a spec_params key' + spec_params = spec.pop('spec_params') + spec_template = Template(json.dumps(spec)) + keys = spec_params.keys() + specs = [] + for idx, vals in enumerate(itertools.product(*spec_params.values())): + spec_str = spec_template.substitute(dict(zip(keys, vals))) + spec = json.loads(spec_str) + spec['name'] += f'_{"_".join(vals)}' + # offset to prevent parallel-run GPU competition, to mod in util.set_cuda_id + cuda_id_gap = int(spec['meta']['max_session'] / spec['meta']['param_spec_process']) + spec['meta']['cuda_offset'] += idx * cuda_id_gap + specs.append(spec) + return specs
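The substitution step can be seen in isolation with a stub spec (this sketch skips the cuda_offset bookkeeping, which needs the full meta section):

import itertools
import json
from string import Template

spec = {'name': 'demo', 'agent': [{'algorithm': {'lr': '$lr'}}]}
spec_params = {'lr': ['0.01', '0.1']}
template = Template(json.dumps(spec))
for vals in itertools.product(*spec_params.values()):
    s = json.loads(template.substitute(dict(zip(spec_params.keys(), vals))))
    print(s['name'] + '_' + '_'.join(vals), s['agent'][0]['algorithm']['lr'])
# demo_0.01 0.01
# demo_0.1 0.1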
+ + +
[docs]def override_dev_spec(spec): + spec['meta']['max_session'] = 1 + spec['meta']['max_trial'] = 2 + return spec
+ + +#def override_enjoy_spec(spec): +# spec['meta']['max_session'] = 1 +# return spec + + +
[docs]def override_eval_spec(spec): + spec['meta']['max_session'] = 1 + # evaluate by episode is set in env clock init in env/base.py + return spec
+
+
+
+def override_test_spec(spec):
+    for agent_spec in spec['agent']:
+        # onpolicy freq is episodic
+        freq = 1 if agent_spec['memory']['name'] == 'OnPolicyReplay' else 8
+        agent_spec['algorithm']['training_frequency'] = freq
+        agent_spec['algorithm']['training_start_step'] = 1
+        agent_spec['algorithm']['training_iter'] = 1
+        agent_spec['algorithm']['training_batch_iter'] = 1
+    for env_spec in spec['env']:
+        env_spec['max_frame'] = 40
+        env_spec['max_t'] = 12
+    spec['meta']['log_frequency'] = 10
+    spec['meta']['eval_frequency'] = 10
+    spec['meta']['max_session'] = 1
+    spec['meta']['max_trial'] = 2
+    return spec
+
+
+
+def save(spec, unit='experiment'):
+    '''Save spec to proper path. Called at Experiment or Trial init.'''
+    prepath = util.get_prepath(spec, unit)
+    util.write(spec, f'{prepath}_spec.json')
+
+
+
+def tick(spec, unit):
+    '''
+    Method to tick lab unit (experiment, trial, session) in meta spec to advance their indices
+    Reset lower lab indices to -1 so that they tick to 0
+    spec_util.tick(spec, 'session')
+    session = Session(spec)
+    '''
+    meta_spec = spec['meta']
+    if unit == 'experiment':
+        meta_spec['experiment_ts'] = util.get_ts()
+        meta_spec['experiment'] += 1
+        meta_spec['trial'] = -1
+        meta_spec['session'] = -1
+    elif unit == 'trial':
+        if meta_spec['experiment'] == -1:
+            meta_spec['experiment'] += 1
+        meta_spec['trial'] += 1
+        meta_spec['session'] = -1
+    elif unit == 'session':
+        if meta_spec['experiment'] == -1:
+            meta_spec['experiment'] += 1
+        if meta_spec['trial'] == -1:
+            meta_spec['trial'] += 1
+        meta_spec['session'] += 1
+    else:
+        raise ValueError(f'Unrecognized lab unit to tick: {unit}')
+    # set prepath since it is determined at this point
+    meta_spec['prepath'] = prepath = util.get_prepath(spec, unit)
+    for folder in ('graph', 'info', 'log', 'model'):
+        folder_prepath = util.insert_folder(prepath, folder)
+        os.makedirs(os.path.dirname(util.smart_path(folder_prepath)), exist_ok=True)
+        meta_spec[f'{folder}_prepath'] = folder_prepath
+    return spec
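Editor's note: the index arithmetic is the subtle part — indices start at -1 and lower units reset so they tick to 0. A minimal sketch of just that logic, with prepath and folder creation omitted:

meta = {'experiment': -1, 'trial': -1, 'session': -1}

def tick_meta(meta, unit):
    # mirrors the trial/session branches above, indices only
    if unit == 'trial':
        if meta['experiment'] == -1:
            meta['experiment'] += 1
        meta['trial'] += 1
        meta['session'] = -1
    elif unit == 'session':
        if meta['experiment'] == -1:
            meta['experiment'] += 1
        if meta['trial'] == -1:
            meta['trial'] += 1
        meta['session'] += 1

tick_meta(meta, 'trial')    # {'experiment': 0, 'trial': 0, 'session': -1}
tick_meta(meta, 'session')  # {'experiment': 0, 'trial': 0, 'session': 0}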
\ No newline at end of file
diff --git a/docs/build/html/_modules/index.html b/docs/build/html/_modules/index.html
new file mode 100644
index 0000000..bf2a7da
--- /dev/null
+++ b/docs/build/html/_modules/index.html
@@ -0,0 +1,288 @@
+[Sphinx HTML page, markup lost in extraction — title "Overview: module code — ConvLab 0.1 documentation", listing header "All modules for which code is available"; theme navigation and breadcrumbs omitted]
+ + + + + + + + + + + + \ No newline at end of file diff --git a/docs/build/html/_sources/convlab.agent.algorithm.rst.txt b/docs/build/html/_sources/convlab.agent.algorithm.rst.txt new file mode 100644 index 0000000..9a72945 --- /dev/null +++ b/docs/build/html/_sources/convlab.agent.algorithm.rst.txt @@ -0,0 +1,94 @@ +convlab.agent.algorithm package +=============================== + +Submodules +---------- + +convlab.agent.algorithm.actor\_critic module +-------------------------------------------- + +.. automodule:: convlab.agent.algorithm.actor_critic + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.algorithm.base module +----------------------------------- + +.. automodule:: convlab.agent.algorithm.base + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.algorithm.dqn module +---------------------------------- + +.. automodule:: convlab.agent.algorithm.dqn + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.algorithm.external module +--------------------------------------- + +.. automodule:: convlab.agent.algorithm.external + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.algorithm.policy\_util module +------------------------------------------- + +.. automodule:: convlab.agent.algorithm.policy_util + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.algorithm.ppo module +---------------------------------- + +.. automodule:: convlab.agent.algorithm.ppo + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.algorithm.random module +------------------------------------- + +.. automodule:: convlab.agent.algorithm.random + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.algorithm.reinforce module +---------------------------------------- + +.. automodule:: convlab.agent.algorithm.reinforce + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.algorithm.sarsa module +------------------------------------ + +.. automodule:: convlab.agent.algorithm.sarsa + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.algorithm.sil module +---------------------------------- + +.. automodule:: convlab.agent.algorithm.sil + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.agent.algorithm + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.agent.memory.rst.txt b/docs/build/html/_sources/convlab.agent.memory.rst.txt new file mode 100644 index 0000000..1bd5913 --- /dev/null +++ b/docs/build/html/_sources/convlab.agent.memory.rst.txt @@ -0,0 +1,46 @@ +convlab.agent.memory package +============================ + +Submodules +---------- + +convlab.agent.memory.base module +-------------------------------- + +.. automodule:: convlab.agent.memory.base + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.memory.onpolicy module +------------------------------------ + +.. automodule:: convlab.agent.memory.onpolicy + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.memory.prioritized module +--------------------------------------- + +.. automodule:: convlab.agent.memory.prioritized + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.memory.replay module +---------------------------------- + +.. automodule:: convlab.agent.memory.replay + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. 
automodule:: convlab.agent.memory + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.agent.net.rst.txt b/docs/build/html/_sources/convlab.agent.net.rst.txt new file mode 100644 index 0000000..bdcb79b --- /dev/null +++ b/docs/build/html/_sources/convlab.agent.net.rst.txt @@ -0,0 +1,54 @@ +convlab.agent.net package +========================= + +Submodules +---------- + +convlab.agent.net.base module +----------------------------- + +.. automodule:: convlab.agent.net.base + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.net.conv module +----------------------------- + +.. automodule:: convlab.agent.net.conv + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.net.mlp module +---------------------------- + +.. automodule:: convlab.agent.net.mlp + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.net.net\_util module +---------------------------------- + +.. automodule:: convlab.agent.net.net_util + :members: + :undoc-members: + :show-inheritance: + +convlab.agent.net.recurrent module +---------------------------------- + +.. automodule:: convlab.agent.net.recurrent + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.agent.net + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.agent.rst.txt b/docs/build/html/_sources/convlab.agent.rst.txt new file mode 100644 index 0000000..9b0a839 --- /dev/null +++ b/docs/build/html/_sources/convlab.agent.rst.txt @@ -0,0 +1,19 @@ +convlab.agent package +===================== + +Subpackages +----------- + +.. toctree:: + + convlab.agent.algorithm + convlab.agent.memory + convlab.agent.net + +Module contents +--------------- + +.. automodule:: convlab.agent + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/tasktk.util.rst b/docs/build/html/_sources/convlab.env.rst.txt similarity index 52% rename from docs/tasktk.util.rst rename to docs/build/html/_sources/convlab.env.rst.txt index df4dbda..04271c6 100644 --- a/docs/tasktk.util.rst +++ b/docs/build/html/_sources/convlab.env.rst.txt @@ -1,29 +1,29 @@ -tasktk.util package +convlab.env package =================== Submodules ---------- -tasktk.util.dataloader module ------------------------------ +convlab.env.base module +----------------------- -.. automodule:: tasktk.util.dataloader +.. automodule:: convlab.env.base :members: :undoc-members: :show-inheritance: -tasktk.util.dialog\_act module ------------------------------- +convlab.env.movie module +------------------------ -.. automodule:: tasktk.util.dialog_act +.. automodule:: convlab.env.movie :members: :undoc-members: :show-inheritance: -tasktk.util.state module ------------------------- +convlab.env.multiwoz module +--------------------------- -.. automodule:: tasktk.util.state +.. automodule:: convlab.env.multiwoz :members: :undoc-members: :show-inheritance: @@ -32,7 +32,7 @@ tasktk.util.state module Module contents --------------- -.. automodule:: tasktk.util +.. automodule:: convlab.env :members: :undoc-members: :show-inheritance: diff --git a/docs/build/html/_sources/convlab.evaluator.rst.txt b/docs/build/html/_sources/convlab.evaluator.rst.txt new file mode 100644 index 0000000..0931c08 --- /dev/null +++ b/docs/build/html/_sources/convlab.evaluator.rst.txt @@ -0,0 +1,30 @@ +convlab.evaluator package +========================= + +Submodules +---------- + +convlab.evaluator.evaluator module +---------------------------------- + +.. 
automodule:: convlab.evaluator.evaluator + :members: + :undoc-members: + :show-inheritance: + +convlab.evaluator.multiwoz module +--------------------------------- + +.. automodule:: convlab.evaluator.multiwoz + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.evaluator + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.experiment.rst.txt b/docs/build/html/_sources/convlab.experiment.rst.txt new file mode 100644 index 0000000..a717b5b --- /dev/null +++ b/docs/build/html/_sources/convlab.experiment.rst.txt @@ -0,0 +1,46 @@ +convlab.experiment package +========================== + +Submodules +---------- + +convlab.experiment.analysis module +---------------------------------- + +.. automodule:: convlab.experiment.analysis + :members: + :undoc-members: + :show-inheritance: + +convlab.experiment.control module +--------------------------------- + +.. automodule:: convlab.experiment.control + :members: + :undoc-members: + :show-inheritance: + +convlab.experiment.retro\_analysis module +----------------------------------------- + +.. automodule:: convlab.experiment.retro_analysis + :members: + :undoc-members: + :show-inheritance: + +convlab.experiment.search module +-------------------------------- + +.. automodule:: convlab.experiment.search + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.experiment + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.human_eval.rst.txt b/docs/build/html/_sources/convlab.human_eval.rst.txt new file mode 100644 index 0000000..7f58d28 --- /dev/null +++ b/docs/build/html/_sources/convlab.human_eval.rst.txt @@ -0,0 +1,86 @@ +convlab.human\_eval package +=========================== + +Submodules +---------- + +convlab.human\_eval.analysis module +----------------------------------- + +.. automodule:: convlab.human_eval.analysis + :members: + :undoc-members: + :show-inheritance: + +convlab.human\_eval.bot\_server module +-------------------------------------- + +.. automodule:: convlab.human_eval.bot_server + :members: + :undoc-members: + :show-inheritance: + +convlab.human\_eval.cambot\_server module +----------------------------------------- + +.. automodule:: convlab.human_eval.cambot_server + :members: + :undoc-members: + :show-inheritance: + +convlab.human\_eval.dqnbot\_server module +----------------------------------------- + +.. automodule:: convlab.human_eval.dqnbot_server + :members: + :undoc-members: + :show-inheritance: + +convlab.human\_eval.rulebot\_server module +------------------------------------------ + +.. automodule:: convlab.human_eval.rulebot_server + :members: + :undoc-members: + :show-inheritance: + +convlab.human\_eval.run module +------------------------------ + +.. automodule:: convlab.human_eval.run + :members: + :undoc-members: + :show-inheritance: + +convlab.human\_eval.sequicity\_server module +-------------------------------------------- + +.. automodule:: convlab.human_eval.sequicity_server + :members: + :undoc-members: + :show-inheritance: + +convlab.human\_eval.task\_config module +--------------------------------------- + +.. automodule:: convlab.human_eval.task_config + :members: + :undoc-members: + :show-inheritance: + +convlab.human\_eval.worlds module +--------------------------------- + +.. 
automodule:: convlab.human_eval.worlds + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.human_eval + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.lib.rst.txt b/docs/build/html/_sources/convlab.lib.rst.txt new file mode 100644 index 0000000..de97c98 --- /dev/null +++ b/docs/build/html/_sources/convlab.lib.rst.txt @@ -0,0 +1,78 @@ +convlab.lib package +=================== + +Submodules +---------- + +convlab.lib.decorator module +---------------------------- + +.. automodule:: convlab.lib.decorator + :members: + :undoc-members: + :show-inheritance: + +convlab.lib.distribution module +------------------------------- + +.. automodule:: convlab.lib.distribution + :members: + :undoc-members: + :show-inheritance: + +convlab.lib.file\_util module +----------------------------- + +.. automodule:: convlab.lib.file_util + :members: + :undoc-members: + :show-inheritance: + +convlab.lib.logger module +------------------------- + +.. automodule:: convlab.lib.logger + :members: + :undoc-members: + :show-inheritance: + +convlab.lib.math\_util module +----------------------------- + +.. automodule:: convlab.lib.math_util + :members: + :undoc-members: + :show-inheritance: + +convlab.lib.optimizer module +---------------------------- + +.. automodule:: convlab.lib.optimizer + :members: + :undoc-members: + :show-inheritance: + +convlab.lib.util module +----------------------- + +.. automodule:: convlab.lib.util + :members: + :undoc-members: + :show-inheritance: + +convlab.lib.viz module +---------------------- + +.. automodule:: convlab.lib.viz + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.lib + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.action_decoder.multiwoz.rst.txt b/docs/build/html/_sources/convlab.modules.action_decoder.multiwoz.rst.txt new file mode 100644 index 0000000..2a52acd --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.action_decoder.multiwoz.rst.txt @@ -0,0 +1,22 @@ +convlab.modules.action\_decoder.multiwoz package +================================================ + +Submodules +---------- + +convlab.modules.action\_decoder.multiwoz.multiwoz\_vocab\_action\_decoder module +-------------------------------------------------------------------------------- + +.. automodule:: convlab.modules.action_decoder.multiwoz.multiwoz_vocab_action_decoder + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.action_decoder.multiwoz + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.action_decoder.rst.txt b/docs/build/html/_sources/convlab.modules.action_decoder.rst.txt new file mode 100644 index 0000000..610f423 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.action_decoder.rst.txt @@ -0,0 +1,17 @@ +convlab.modules.action\_decoder package +======================================= + +Subpackages +----------- + +.. toctree:: + + convlab.modules.action_decoder.multiwoz + +Module contents +--------------- + +.. 
automodule:: convlab.modules.action_decoder + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.dst.multiwoz.rst.txt b/docs/build/html/_sources/convlab.modules.dst.multiwoz.rst.txt new file mode 100644 index 0000000..0d8c353 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.dst.multiwoz.rst.txt @@ -0,0 +1,38 @@ +convlab.modules.dst.multiwoz package +==================================== + +Submodules +---------- + +convlab.modules.dst.multiwoz.dst\_util module +--------------------------------------------- + +.. automodule:: convlab.modules.dst.multiwoz.dst_util + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.dst.multiwoz.evaluate module +-------------------------------------------- + +.. automodule:: convlab.modules.dst.multiwoz.evaluate + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.dst.multiwoz.rule\_dst module +--------------------------------------------- + +.. automodule:: convlab.modules.dst.multiwoz.rule_dst + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.dst.multiwoz + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.dst.rst.txt b/docs/build/html/_sources/convlab.modules.dst.rst.txt new file mode 100644 index 0000000..6e7d2cc --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.dst.rst.txt @@ -0,0 +1,29 @@ +convlab.modules.dst package +=========================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.dst.multiwoz + +Submodules +---------- + +convlab.modules.dst.state\_tracker module +----------------------------------------- + +.. automodule:: convlab.modules.dst.state_tracker + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.dst + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.e2e.multiwoz.Mem2Seq.models.rst.txt b/docs/build/html/_sources/convlab.modules.e2e.multiwoz.Mem2Seq.models.rst.txt new file mode 100644 index 0000000..30e673e --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.e2e.multiwoz.Mem2Seq.models.rst.txt @@ -0,0 +1,54 @@ +convlab.modules.e2e.multiwoz.Mem2Seq.models package +=================================================== + +Submodules +---------- + +convlab.modules.e2e.multiwoz.Mem2Seq.models.Mem2Seq module +---------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.models.Mem2Seq + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.models.Mem2Seq\_NMT module +--------------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.models.Mem2Seq_NMT + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.models.enc\_Luong module +------------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.models.enc_Luong + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.models.enc\_PTRUNK module +-------------------------------------------------------------- + +.. 
automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.models.enc_PTRUNK + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.models.enc\_vanilla module +--------------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.models.enc_vanilla + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.models + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.e2e.multiwoz.Mem2Seq.rst.txt b/docs/build/html/_sources/convlab.modules.e2e.multiwoz.Mem2Seq.rst.txt new file mode 100644 index 0000000..ffb8973 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.e2e.multiwoz.Mem2Seq.rst.txt @@ -0,0 +1,62 @@ +convlab.modules.e2e.multiwoz.Mem2Seq package +============================================ + +Subpackages +----------- + +.. toctree:: + + convlab.modules.e2e.multiwoz.Mem2Seq.models + convlab.modules.e2e.multiwoz.Mem2Seq.utils + +Submodules +---------- + +convlab.modules.e2e.multiwoz.Mem2Seq.Mem2Seq module +--------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.Mem2Seq + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.main\_interact module +---------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.main_interact + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.main\_nmt module +----------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.main_nmt + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.main\_test module +------------------------------------------------------ + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.main_test + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.main\_train module +------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.main_train + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.e2e.multiwoz.Mem2Seq.utils.rst.txt b/docs/build/html/_sources/convlab.modules.e2e.multiwoz.Mem2Seq.utils.rst.txt new file mode 100644 index 0000000..16753b1 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.e2e.multiwoz.Mem2Seq.utils.rst.txt @@ -0,0 +1,94 @@ +convlab.modules.e2e.multiwoz.Mem2Seq.utils package +================================================== + +Submodules +---------- + +convlab.modules.e2e.multiwoz.Mem2Seq.utils.config module +-------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.utils.config + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.utils.masked\_cross\_entropy module +------------------------------------------------------------------------ + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.utils.masked_cross_entropy + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.utils.measures module +---------------------------------------------------------- + +.. 
automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.utils.measures + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.utils.until\_temp module +------------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.utils.until_temp + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.utils.utils\_NMT module +------------------------------------------------------------ + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.utils.utils_NMT + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.utils.utils\_babi module +------------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.utils.utils_babi + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.utils.utils\_babi\_mem2seq module +---------------------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.utils.utils_babi_mem2seq + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.utils.utils\_kvr module +------------------------------------------------------------ + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.utils.utils_kvr + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.utils.utils\_kvr\_mem2seq module +--------------------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.utils.utils_kvr_mem2seq + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Mem2Seq.utils.utils\_woz\_mem2seq module +--------------------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.utils.utils_woz_mem2seq + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Mem2Seq.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.e2e.multiwoz.Sequicity.rst.txt b/docs/build/html/_sources/convlab.modules.e2e.multiwoz.Sequicity.rst.txt new file mode 100644 index 0000000..83bae20 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.e2e.multiwoz.Sequicity.rst.txt @@ -0,0 +1,62 @@ +convlab.modules.e2e.multiwoz.Sequicity package +============================================== + +Submodules +---------- + +convlab.modules.e2e.multiwoz.Sequicity.Sequicity module +------------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Sequicity.Sequicity + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Sequicity.config module +---------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Sequicity.config + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Sequicity.metric module +---------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Sequicity.metric + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Sequicity.model module +--------------------------------------------------- + +.. 
automodule:: convlab.modules.e2e.multiwoz.Sequicity.model + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Sequicity.reader module +---------------------------------------------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Sequicity.reader + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.e2e.multiwoz.Sequicity.tsd\_net module +------------------------------------------------------ + +.. automodule:: convlab.modules.e2e.multiwoz.Sequicity.tsd_net + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.e2e.multiwoz.Sequicity + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.e2e.multiwoz.rst.txt b/docs/build/html/_sources/convlab.modules.e2e.multiwoz.rst.txt new file mode 100644 index 0000000..1ac27a4 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.e2e.multiwoz.rst.txt @@ -0,0 +1,18 @@ +convlab.modules.e2e.multiwoz package +==================================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.e2e.multiwoz.Mem2Seq + convlab.modules.e2e.multiwoz.Sequicity + +Module contents +--------------- + +.. automodule:: convlab.modules.e2e.multiwoz + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.e2e.rst.txt b/docs/build/html/_sources/convlab.modules.e2e.rst.txt new file mode 100644 index 0000000..c5ede15 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.e2e.rst.txt @@ -0,0 +1,17 @@ +convlab.modules.e2e package +=========================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.e2e.multiwoz + +Module contents +--------------- + +.. automodule:: convlab.modules.e2e + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.nlg.multiwoz.multiwoz_template_nlg.rst.txt b/docs/build/html/_sources/convlab.modules.nlg.multiwoz.multiwoz_template_nlg.rst.txt new file mode 100644 index 0000000..1c21f7d --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.nlg.multiwoz.multiwoz_template_nlg.rst.txt @@ -0,0 +1,22 @@ +convlab.modules.nlg.multiwoz.multiwoz\_template\_nlg package +============================================================ + +Submodules +---------- + +convlab.modules.nlg.multiwoz.multiwoz\_template\_nlg.multiwoz\_template\_nlg module +----------------------------------------------------------------------------------- + +.. automodule:: convlab.modules.nlg.multiwoz.multiwoz_template_nlg.multiwoz_template_nlg + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.nlg.multiwoz.multiwoz_template_nlg + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.nlg.multiwoz.rst.txt b/docs/build/html/_sources/convlab.modules.nlg.multiwoz.rst.txt new file mode 100644 index 0000000..611320f --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.nlg.multiwoz.rst.txt @@ -0,0 +1,46 @@ +convlab.modules.nlg.multiwoz package +==================================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.nlg.multiwoz.multiwoz_template_nlg + convlab.modules.nlg.multiwoz.sc_lstm + +Submodules +---------- + +convlab.modules.nlg.multiwoz.evaluate module +-------------------------------------------- + +.. 
automodule:: convlab.modules.nlg.multiwoz.evaluate + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlg.multiwoz.template\_nlg module +------------------------------------------------- + +.. automodule:: convlab.modules.nlg.multiwoz.template_nlg + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlg.multiwoz.utils module +----------------------------------------- + +.. automodule:: convlab.modules.nlg.multiwoz.utils + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.nlg.multiwoz + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.nlg.multiwoz.sc_lstm.rst.txt b/docs/build/html/_sources/convlab.modules.nlg.multiwoz.sc_lstm.rst.txt new file mode 100644 index 0000000..252c4ae --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.nlg.multiwoz.sc_lstm.rst.txt @@ -0,0 +1,38 @@ +convlab.modules.nlg.multiwoz.sc\_lstm package +============================================= + +Submodules +---------- + +convlab.modules.nlg.multiwoz.sc\_lstm.bleu module +------------------------------------------------- + +.. automodule:: convlab.modules.nlg.multiwoz.sc_lstm.bleu + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlg.multiwoz.sc\_lstm.nlg\_sc\_lstm module +---------------------------------------------------------- + +.. automodule:: convlab.modules.nlg.multiwoz.sc_lstm.nlg_sc_lstm + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlg.multiwoz.sc\_lstm.run\_woz module +----------------------------------------------------- + +.. automodule:: convlab.modules.nlg.multiwoz.sc_lstm.run_woz + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.nlg.multiwoz.sc_lstm + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.nlg.rst.txt b/docs/build/html/_sources/convlab.modules.nlg.rst.txt new file mode 100644 index 0000000..26f0149 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.nlg.rst.txt @@ -0,0 +1,29 @@ +convlab.modules.nlg package +=========================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.nlg.multiwoz + +Submodules +---------- + +convlab.modules.nlg.nlg module +------------------------------ + +.. automodule:: convlab.modules.nlg.nlg + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.nlg + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.nlu.multiwoz.milu.rst.txt b/docs/build/html/_sources/convlab.modules.nlu.multiwoz.milu.rst.txt new file mode 100644 index 0000000..ab0fbaa --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.nlu.multiwoz.milu.rst.txt @@ -0,0 +1,70 @@ +convlab.modules.nlu.multiwoz.milu package +========================================= + +Submodules +---------- + +convlab.modules.nlu.multiwoz.milu.dai\_f1\_measure module +--------------------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.milu.dai_f1_measure + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.milu.dataset\_reader module +-------------------------------------------------------- + +.. 
automodule:: convlab.modules.nlu.multiwoz.milu.dataset_reader + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.milu.evaluate module +------------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.milu.evaluate + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.milu.model module +---------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.milu.model + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.milu.multilabel\_f1\_measure module +---------------------------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.milu.multilabel_f1_measure + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.milu.nlu module +-------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.milu.nlu + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.milu.train module +---------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.milu.train + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.nlu.multiwoz.milu + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.nlu.multiwoz.onenet.rst.txt b/docs/build/html/_sources/convlab.modules.nlu.multiwoz.onenet.rst.txt new file mode 100644 index 0000000..4aabb90 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.nlu.multiwoz.onenet.rst.txt @@ -0,0 +1,62 @@ +convlab.modules.nlu.multiwoz.onenet package +=========================================== + +Submodules +---------- + +convlab.modules.nlu.multiwoz.onenet.dai\_f1\_measure module +----------------------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.onenet.dai_f1_measure + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.onenet.dataset\_reader module +---------------------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.onenet.dataset_reader + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.onenet.evaluate module +--------------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.onenet.evaluate + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.onenet.model module +------------------------------------------------ + +.. automodule:: convlab.modules.nlu.multiwoz.onenet.model + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.onenet.nlu module +---------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.onenet.nlu + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.onenet.train module +------------------------------------------------ + +.. automodule:: convlab.modules.nlu.multiwoz.onenet.train + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. 
automodule:: convlab.modules.nlu.multiwoz.onenet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.nlu.multiwoz.rst.txt b/docs/build/html/_sources/convlab.modules.nlu.multiwoz.rst.txt new file mode 100644 index 0000000..9d597b1 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.nlu.multiwoz.rst.txt @@ -0,0 +1,47 @@ +convlab.modules.nlu.multiwoz package +==================================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.nlu.multiwoz.milu + convlab.modules.nlu.multiwoz.onenet + convlab.modules.nlu.multiwoz.svm + +Submodules +---------- + +convlab.modules.nlu.multiwoz.error module +----------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.error + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.evaluate module +-------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.evaluate + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.utils module +----------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.utils + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.nlu.multiwoz + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.nlu.multiwoz.svm.rst.txt b/docs/build/html/_sources/convlab.modules.nlu.multiwoz.svm.rst.txt new file mode 100644 index 0000000..6824b39 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.nlu.multiwoz.svm.rst.txt @@ -0,0 +1,70 @@ +convlab.modules.nlu.multiwoz.svm package +======================================== + +Submodules +---------- + +convlab.modules.nlu.multiwoz.svm.Classifier module +-------------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.svm.Classifier + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.svm.Features module +------------------------------------------------ + +.. automodule:: convlab.modules.nlu.multiwoz.svm.Features + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.svm.Tuples module +---------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.svm.Tuples + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.svm.nlu module +------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.svm.nlu + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.svm.preprocess module +-------------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.svm.preprocess + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.svm.sutils module +---------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.svm.sutils + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.nlu.multiwoz.svm.train module +--------------------------------------------- + +.. automodule:: convlab.modules.nlu.multiwoz.svm.train + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. 
automodule:: convlab.modules.nlu.multiwoz.svm + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.nlu.rst.txt b/docs/build/html/_sources/convlab.modules.nlu.rst.txt new file mode 100644 index 0000000..71a9af4 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.nlu.rst.txt @@ -0,0 +1,29 @@ +convlab.modules.nlu package +=========================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.nlu.multiwoz + +Submodules +---------- + +convlab.modules.nlu.nlu module +------------------------------ + +.. automodule:: convlab.modules.nlu.nlu + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.nlu + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.policy.rst.txt b/docs/build/html/_sources/convlab.modules.policy.rst.txt new file mode 100644 index 0000000..b87403b --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.policy.rst.txt @@ -0,0 +1,18 @@ +convlab.modules.policy package +============================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.policy.system + convlab.modules.policy.user + +Module contents +--------------- + +.. automodule:: convlab.modules.policy + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.policy.system.multiwoz.rst.txt b/docs/build/html/_sources/convlab.modules.policy.system.multiwoz.rst.txt new file mode 100644 index 0000000..f16ce26 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.policy.system.multiwoz.rst.txt @@ -0,0 +1,37 @@ +convlab.modules.policy.system.multiwoz package +============================================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.policy.system.multiwoz.vanilla_mle + +Submodules +---------- + +convlab.modules.policy.system.multiwoz.rule\_based\_multiwoz\_bot module +------------------------------------------------------------------------ + +.. automodule:: convlab.modules.policy.system.multiwoz.rule_based_multiwoz_bot + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.policy.system.multiwoz.util module +-------------------------------------------------- + +.. automodule:: convlab.modules.policy.system.multiwoz.util + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.policy.system.multiwoz + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.policy.system.multiwoz.vanilla_mle.rst.txt b/docs/build/html/_sources/convlab.modules.policy.system.multiwoz.vanilla_mle.rst.txt new file mode 100644 index 0000000..7201a25 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.policy.system.multiwoz.vanilla_mle.rst.txt @@ -0,0 +1,54 @@ +convlab.modules.policy.system.multiwoz.vanilla\_mle package +=========================================================== + +Submodules +---------- + +convlab.modules.policy.system.multiwoz.vanilla\_mle.dataset\_reader module +-------------------------------------------------------------------------- + +.. automodule:: convlab.modules.policy.system.multiwoz.vanilla_mle.dataset_reader + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.policy.system.multiwoz.vanilla\_mle.evaluate module +------------------------------------------------------------------- + +.. 
automodule:: convlab.modules.policy.system.multiwoz.vanilla_mle.evaluate + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.policy.system.multiwoz.vanilla\_mle.model module +---------------------------------------------------------------- + +.. automodule:: convlab.modules.policy.system.multiwoz.vanilla_mle.model + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.policy.system.multiwoz.vanilla\_mle.policy module +----------------------------------------------------------------- + +.. automodule:: convlab.modules.policy.system.multiwoz.vanilla_mle.policy + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.policy.system.multiwoz.vanilla\_mle.train module +---------------------------------------------------------------- + +.. automodule:: convlab.modules.policy.system.multiwoz.vanilla_mle.train + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.policy.system.multiwoz.vanilla_mle + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.policy.system.rst.txt b/docs/build/html/_sources/convlab.modules.policy.system.rst.txt new file mode 100644 index 0000000..d7f0a2d --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.policy.system.rst.txt @@ -0,0 +1,29 @@ +convlab.modules.policy.system package +===================================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.policy.system.multiwoz + +Submodules +---------- + +convlab.modules.policy.system.policy module +------------------------------------------- + +.. automodule:: convlab.modules.policy.system.policy + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.policy.system + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.policy.user.multiwoz.rst.txt b/docs/build/html/_sources/convlab.modules.policy.user.multiwoz.rst.txt new file mode 100644 index 0000000..e444762 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.policy.user.multiwoz.rst.txt @@ -0,0 +1,30 @@ +convlab.modules.policy.user.multiwoz package +============================================ + +Submodules +---------- + +convlab.modules.policy.user.multiwoz.policy\_agenda\_multiwoz module +-------------------------------------------------------------------- + +.. automodule:: convlab.modules.policy.user.multiwoz.policy_agenda_multiwoz + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.policy.user.multiwoz.policy\_vhus module +-------------------------------------------------------- + +.. automodule:: convlab.modules.policy.user.multiwoz.policy_vhus + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.policy.user.multiwoz + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.policy.user.rst.txt b/docs/build/html/_sources/convlab.modules.policy.user.rst.txt new file mode 100644 index 0000000..517e714 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.policy.user.rst.txt @@ -0,0 +1,29 @@ +convlab.modules.policy.user package +=================================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.policy.user.multiwoz + +Submodules +---------- + +convlab.modules.policy.user.policy module +----------------------------------------- + +.. 
automodule:: convlab.modules.policy.user.policy + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.policy.user + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.rst.txt b/docs/build/html/_sources/convlab.modules.rst.txt new file mode 100644 index 0000000..0802cb0 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.rst.txt @@ -0,0 +1,27 @@ +convlab.modules package +======================= + +Subpackages +----------- + +.. toctree:: + + convlab.modules.action_decoder + convlab.modules.dst + convlab.modules.e2e + convlab.modules.nlg + convlab.modules.nlu + convlab.modules.policy + convlab.modules.state_encoder + convlab.modules.usr + convlab.modules.util + convlab.modules.word_dst + convlab.modules.word_policy + +Module contents +--------------- + +.. automodule:: convlab.modules + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.state_encoder.multiwoz.rst.txt b/docs/build/html/_sources/convlab.modules.state_encoder.multiwoz.rst.txt new file mode 100644 index 0000000..e75f809 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.state_encoder.multiwoz.rst.txt @@ -0,0 +1,22 @@ +convlab.modules.state\_encoder.multiwoz package +=============================================== + +Submodules +---------- + +convlab.modules.state\_encoder.multiwoz.multiwoz\_state\_encoder module +----------------------------------------------------------------------- + +.. automodule:: convlab.modules.state_encoder.multiwoz.multiwoz_state_encoder + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.state_encoder.multiwoz + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.state_encoder.rst.txt b/docs/build/html/_sources/convlab.modules.state_encoder.rst.txt new file mode 100644 index 0000000..2e99403 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.state_encoder.rst.txt @@ -0,0 +1,17 @@ +convlab.modules.state\_encoder package +====================================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.state_encoder.multiwoz + +Module contents +--------------- + +.. automodule:: convlab.modules.state_encoder + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.usr.multiwoz.rst.txt b/docs/build/html/_sources/convlab.modules.usr.multiwoz.rst.txt new file mode 100644 index 0000000..8604ff0 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.usr.multiwoz.rst.txt @@ -0,0 +1,22 @@ +convlab.modules.usr.multiwoz package +==================================== + +Submodules +---------- + +convlab.modules.usr.multiwoz.goal\_generator module +--------------------------------------------------- + +.. automodule:: convlab.modules.usr.multiwoz.goal_generator + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.usr.multiwoz + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.usr.rst.txt b/docs/build/html/_sources/convlab.modules.usr.rst.txt new file mode 100644 index 0000000..3d0c8ad --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.usr.rst.txt @@ -0,0 +1,29 @@ +convlab.modules.usr package +=========================== + +Subpackages +----------- + +.. 
toctree:: + + convlab.modules.usr.multiwoz + +Submodules +---------- + +convlab.modules.usr.user module +------------------------------- + +.. automodule:: convlab.modules.usr.user + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.usr + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.util.rst.txt b/docs/build/html/_sources/convlab.modules.util.rst.txt new file mode 100644 index 0000000..dc092a3 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.util.rst.txt @@ -0,0 +1,10 @@ +convlab.modules.util package +============================ + +Module contents +--------------- + +.. automodule:: convlab.modules.util + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.word_dst.multiwoz.rst.txt b/docs/build/html/_sources/convlab.modules.word_dst.multiwoz.rst.txt new file mode 100644 index 0000000..e5f5e90 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.word_dst.multiwoz.rst.txt @@ -0,0 +1,38 @@ +convlab.modules.word\_dst.multiwoz package +========================================== + +Submodules +---------- + +convlab.modules.word\_dst.multiwoz.evaluate module +-------------------------------------------------- + +.. automodule:: convlab.modules.word_dst.multiwoz.evaluate + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.word\_dst.multiwoz.mdbt module +---------------------------------------------- + +.. automodule:: convlab.modules.word_dst.multiwoz.mdbt + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.word\_dst.multiwoz.mdbt\_util module +---------------------------------------------------- + +.. automodule:: convlab.modules.word_dst.multiwoz.mdbt_util + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.word_dst.multiwoz + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.word_dst.rst.txt b/docs/build/html/_sources/convlab.modules.word_dst.rst.txt new file mode 100644 index 0000000..91984cf --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.word_dst.rst.txt @@ -0,0 +1,17 @@ +convlab.modules.word\_dst package +================================= + +Subpackages +----------- + +.. toctree:: + + convlab.modules.word_dst.multiwoz + +Module contents +--------------- + +.. automodule:: convlab.modules.word_dst + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.word_policy.multiwoz.mdrg.model.rst.txt b/docs/build/html/_sources/convlab.modules.word_policy.multiwoz.mdrg.model.rst.txt new file mode 100644 index 0000000..a6de8ee --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.word_policy.multiwoz.mdrg.model.rst.txt @@ -0,0 +1,38 @@ +convlab.modules.word\_policy.multiwoz.mdrg.model package +======================================================== + +Submodules +---------- + +convlab.modules.word\_policy.multiwoz.mdrg.model.evaluator module +----------------------------------------------------------------- + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg.model.evaluator + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.word\_policy.multiwoz.mdrg.model.model module +------------------------------------------------------------- + +.. 
automodule:: convlab.modules.word_policy.multiwoz.mdrg.model.model + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.word\_policy.multiwoz.mdrg.model.policy module +-------------------------------------------------------------- + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg.model.policy + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg.model + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.word_policy.multiwoz.mdrg.rst.txt b/docs/build/html/_sources/convlab.modules.word_policy.multiwoz.mdrg.rst.txt new file mode 100644 index 0000000..700f996 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.word_policy.multiwoz.mdrg.rst.txt @@ -0,0 +1,54 @@ +convlab.modules.word\_policy.multiwoz.mdrg package +================================================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.word_policy.multiwoz.mdrg.model + convlab.modules.word_policy.multiwoz.mdrg.utils + +Submodules +---------- + +convlab.modules.word\_policy.multiwoz.mdrg.create\_delex\_data module +--------------------------------------------------------------------- + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg.create_delex_data + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.word\_policy.multiwoz.mdrg.policy module +-------------------------------------------------------- + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg.policy + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.word\_policy.multiwoz.mdrg.test module +------------------------------------------------------ + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg.test + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.word\_policy.multiwoz.mdrg.train module +------------------------------------------------------- + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg.train + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.word_policy.multiwoz.mdrg.utils.rst.txt b/docs/build/html/_sources/convlab.modules.word_policy.multiwoz.mdrg.utils.rst.txt new file mode 100644 index 0000000..7253b87 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.word_policy.multiwoz.mdrg.utils.rst.txt @@ -0,0 +1,54 @@ +convlab.modules.word\_policy.multiwoz.mdrg.utils package +======================================================== + +Submodules +---------- + +convlab.modules.word\_policy.multiwoz.mdrg.utils.dbPointer module +----------------------------------------------------------------- + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg.utils.dbPointer + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.word\_policy.multiwoz.mdrg.utils.dbquery module +--------------------------------------------------------------- + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg.utils.dbquery + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.word\_policy.multiwoz.mdrg.utils.delexicalize module +-------------------------------------------------------------------- + +.. 
automodule:: convlab.modules.word_policy.multiwoz.mdrg.utils.delexicalize + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.word\_policy.multiwoz.mdrg.utils.nlp module +----------------------------------------------------------- + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg.utils.nlp + :members: + :undoc-members: + :show-inheritance: + +convlab.modules.word\_policy.multiwoz.mdrg.utils.util module +------------------------------------------------------------ + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg.utils.util + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: convlab.modules.word_policy.multiwoz.mdrg.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.word_policy.multiwoz.rst.txt b/docs/build/html/_sources/convlab.modules.word_policy.multiwoz.rst.txt new file mode 100644 index 0000000..3583249 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.word_policy.multiwoz.rst.txt @@ -0,0 +1,17 @@ +convlab.modules.word\_policy.multiwoz package +============================================= + +Subpackages +----------- + +.. toctree:: + + convlab.modules.word_policy.multiwoz.mdrg + +Module contents +--------------- + +.. automodule:: convlab.modules.word_policy.multiwoz + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.modules.word_policy.rst.txt b/docs/build/html/_sources/convlab.modules.word_policy.rst.txt new file mode 100644 index 0000000..e3afcb5 --- /dev/null +++ b/docs/build/html/_sources/convlab.modules.word_policy.rst.txt @@ -0,0 +1,17 @@ +convlab.modules.word\_policy package +==================================== + +Subpackages +----------- + +.. toctree:: + + convlab.modules.word_policy.multiwoz + +Module contents +--------------- + +.. automodule:: convlab.modules.word_policy + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/convlab.rst.txt b/docs/build/html/_sources/convlab.rst.txt new file mode 100644 index 0000000..eae7741 --- /dev/null +++ b/docs/build/html/_sources/convlab.rst.txt @@ -0,0 +1,24 @@ +convlab package +=============== + +Subpackages +----------- + +.. toctree:: + + convlab.agent + convlab.env + convlab.evaluator + convlab.experiment + convlab.human_eval + convlab.lib + convlab.modules + convlab.spec + +Module contents +--------------- + +.. automodule:: convlab + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/build/html/_sources/tasktk.usr.rst.txt b/docs/build/html/_sources/convlab.spec.rst.txt similarity index 50% rename from docs/build/html/_sources/tasktk.usr.rst.txt rename to docs/build/html/_sources/convlab.spec.rst.txt index 24ac14a..6ca6ccd 100644 --- a/docs/build/html/_sources/tasktk.usr.rst.txt +++ b/docs/build/html/_sources/convlab.spec.rst.txt @@ -1,13 +1,13 @@ -tasktk.usr package -================== +convlab.spec package +==================== Submodules ---------- -tasktk.usr.simulator module ---------------------------- +convlab.spec.spec\_util module +------------------------------ -.. automodule:: tasktk.usr.simulator +.. automodule:: convlab.spec.spec_util :members: :undoc-members: :show-inheritance: @@ -16,7 +16,7 @@ tasktk.usr.simulator module Module contents --------------- -.. automodule:: tasktk.usr +.. 
automodule:: convlab.spec :members: :undoc-members: :show-inheritance: diff --git a/docs/build/html/_sources/index.rst.txt b/docs/build/html/_sources/index.rst.txt index 4fc7888..318c1b5 100644 --- a/docs/build/html/_sources/index.rst.txt +++ b/docs/build/html/_sources/index.rst.txt @@ -1,32 +1,14 @@ -.. Tasktk documentation master file, created by - sphinx-quickstart on Thu Jan 17 17:29:23 2019. +.. ConvLab documentation master file, created by + sphinx-quickstart on Tue Jul 2 16:38:15 2019. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -:github_url: https://github.com/xjli/DialogZone - -Tasktk documentation -================================== -This is a toolkit for developing task-oriented dialog systems. We follow the classical pipeline framework, in which there are four separate components: NLU, DST, Policy and NLG. - -We offer the base class and some SOTA baseline models (coming soon) for each component. Specifically, the NLU, DST and NLG models are trained individually, while the Policy is trained within a complete pipeline model in an RL-based manner. +Welcome to ConvLab's documentation! +=================================== .. toctree:: - :glob: - :maxdepth: 1 - :caption: Notes - - n - -.. toctree:: - :maxdepth: 1 - :caption: Package Reference - - nlu - -Introduction -============ -This is an introductory demo. + :maxdepth: 2 + :caption: Contents: diff --git a/docs/build/html/_sources/modules.rst.txt b/docs/build/html/_sources/modules.rst.txt index 270746c..f2b310d 100644 --- a/docs/build/html/_sources/modules.rst.txt +++ b/docs/build/html/_sources/modules.rst.txt @@ -1,7 +1,7 @@ -tasktk -====== +convlab +======= .. toctree:: :maxdepth: 4 - tasktk + convlab diff --git a/docs/build/html/_sources/nlu.rst.txt b/docs/build/html/_sources/nlu.rst.txt deleted file mode 100644 index 8c42b81..0000000 --- a/docs/build/html/_sources/nlu.rst.txt +++ /dev/null @@ -1,16 +0,0 @@ -NLU -######### -.. automodule:: task.nlu - -.. autoclass:: NLU - :members: - -NLU class ---------------------------------- -.. autoclass:: RuleNLU - :members: - -.. autoclass:: TrainableNLU - :members: - -Metric-like class \ No newline at end of file diff --git a/docs/build/html/_sources/tasktk.dialog_agent.rst.txt b/docs/build/html/_sources/tasktk.dialog_agent.rst.txt deleted file mode 100644 index f140cb8..0000000 --- a/docs/build/html/_sources/tasktk.dialog_agent.rst.txt +++ /dev/null @@ -1,22 +0,0 @@ -tasktk.dialog\_agent package -============================ - -Submodules ----------- - -tasktk.dialog\_agent.system module ----------------------------------- - -.. automodule:: tasktk.dialog_agent.system - :members: - :undoc-members: - :show-inheritance: - - -Module contents ---------------- - -.. automodule:: tasktk.dialog_agent - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/build/html/_sources/tasktk.dst.rst.txt b/docs/build/html/_sources/tasktk.dst.rst.txt deleted file mode 100644 index b1179de..0000000 --- a/docs/build/html/_sources/tasktk.dst.rst.txt +++ /dev/null @@ -1,22 +0,0 @@ -tasktk.dst package -================== - -Submodules ----------- - -tasktk.dst.state\_tracker module -------------------------------- - -.. automodule:: tasktk.dst.state_tracker - :members: - :undoc-members: - :show-inheritance: - - -Module contents ---------------- - -..
automodule:: tasktk.dst - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/build/html/_sources/tasktk.nlg.rst.txt b/docs/build/html/_sources/tasktk.nlg.rst.txt deleted file mode 100644 index d6a8630..0000000 --- a/docs/build/html/_sources/tasktk.nlg.rst.txt +++ /dev/null @@ -1,22 +0,0 @@ -tasktk.nlg package -================== - -Submodules ----------- - -tasktk.nlg.nlg module ---------------------- - -.. automodule:: tasktk.nlg.nlg - :members: - :undoc-members: - :show-inheritance: - - -Module contents ---------------- - -.. automodule:: tasktk.nlg - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/build/html/_sources/tasktk.nlu.rst.txt b/docs/build/html/_sources/tasktk.nlu.rst.txt deleted file mode 100644 index c8bfba9..0000000 --- a/docs/build/html/_sources/tasktk.nlu.rst.txt +++ /dev/null @@ -1,30 +0,0 @@ -tasktk.nlu package -================== - -Submodules ----------- - -tasktk.nlu.error module ------------------------ - -.. automodule:: tasktk.nlu.error - :members: - :undoc-members: - :show-inheritance: - -tasktk.nlu.nlu module ---------------------- - -.. automodule:: tasktk.nlu.nlu - :members: - :undoc-members: - :show-inheritance: - - -Module contents ---------------- - -.. automodule:: tasktk.nlu - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/build/html/_sources/tasktk.policy.rst.txt b/docs/build/html/_sources/tasktk.policy.rst.txt deleted file mode 100644 index e4e237e..0000000 --- a/docs/build/html/_sources/tasktk.policy.rst.txt +++ /dev/null @@ -1,30 +0,0 @@ -tasktk.policy package -===================== - -Submodules ----------- - -tasktk.policy.policy module ---------------------------- - -.. automodule:: tasktk.policy.policy - :members: - :undoc-members: - :show-inheritance: - -tasktk.policy.policy\_user\_rule module ---------------------------------------- - -.. automodule:: tasktk.policy.policy_user_rule - :members: - :undoc-members: - :show-inheritance: - - -Module contents ---------------- - -.. automodule:: tasktk.policy - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/build/html/_sources/tasktk.rst.txt b/docs/build/html/_sources/tasktk.rst.txt deleted file mode 100644 index 0555490..0000000 --- a/docs/build/html/_sources/tasktk.rst.txt +++ /dev/null @@ -1,35 +0,0 @@ -tasktk package -============== - -Subpackages ------------ - -.. toctree:: - - tasktk.dialog_agent - tasktk.dst - tasktk.nlg - tasktk.nlu - tasktk.policy - tasktk.usr - tasktk.util - -Submodules ----------- - -tasktk.dialog\_session module ------------------------------ - -.. automodule:: tasktk.dialog_session - :members: - :undoc-members: - :show-inheritance: - - -Module contents ---------------- - -.. 
automodule:: tasktk - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/build/html/_static/alabaster.css b/docs/build/html/_static/alabaster.css new file mode 100644 index 0000000..0eddaeb --- /dev/null +++ b/docs/build/html/_static/alabaster.css @@ -0,0 +1,701 @@ +@import url("basic.css"); + +/* -- page layout ----------------------------------------------------------- */ + +body { + font-family: Georgia, serif; + font-size: 17px; + background-color: #fff; + color: #000; + margin: 0; + padding: 0; +} + + +div.document { + width: 940px; + margin: 30px auto 0 auto; +} + +div.documentwrapper { + float: left; + width: 100%; +} + +div.bodywrapper { + margin: 0 0 0 220px; +} + +div.sphinxsidebar { + width: 220px; + font-size: 14px; + line-height: 1.5; +} + +hr { + border: 1px solid #B1B4B6; +} + +div.body { + background-color: #fff; + color: #3E4349; + padding: 0 30px 0 30px; +} + +div.body > .section { + text-align: left; +} + +div.footer { + width: 940px; + margin: 20px auto 30px auto; + font-size: 14px; + color: #888; + text-align: right; +} + +div.footer a { + color: #888; +} + +p.caption { + font-family: inherit; + font-size: inherit; +} + + +div.relations { + display: none; +} + + +div.sphinxsidebar a { + color: #444; + text-decoration: none; + border-bottom: 1px dotted #999; +} + +div.sphinxsidebar a:hover { + border-bottom: 1px solid #999; +} + +div.sphinxsidebarwrapper { + padding: 18px 10px; +} + +div.sphinxsidebarwrapper p.logo { + padding: 0; + margin: -10px 0 0 0px; + text-align: center; +} + +div.sphinxsidebarwrapper h1.logo { + margin-top: -10px; + text-align: center; + margin-bottom: 5px; + text-align: left; +} + +div.sphinxsidebarwrapper h1.logo-name { + margin-top: 0px; +} + +div.sphinxsidebarwrapper p.blurb { + margin-top: 0; + font-style: normal; +} + +div.sphinxsidebar h3, +div.sphinxsidebar h4 { + font-family: Georgia, serif; + color: #444; + font-size: 24px; + font-weight: normal; + margin: 0 0 5px 0; + padding: 0; +} + +div.sphinxsidebar h4 { + font-size: 20px; +} + +div.sphinxsidebar h3 a { + color: #444; +} + +div.sphinxsidebar p.logo a, +div.sphinxsidebar h3 a, +div.sphinxsidebar p.logo a:hover, +div.sphinxsidebar h3 a:hover { + border: none; +} + +div.sphinxsidebar p { + color: #555; + margin: 10px 0; +} + +div.sphinxsidebar ul { + margin: 10px 0; + padding: 0; + color: #000; +} + +div.sphinxsidebar ul li.toctree-l1 > a { + font-size: 120%; +} + +div.sphinxsidebar ul li.toctree-l2 > a { + font-size: 110%; +} + +div.sphinxsidebar input { + border: 1px solid #CCC; + font-family: Georgia, serif; + font-size: 1em; +} + +div.sphinxsidebar hr { + border: none; + height: 1px; + color: #AAA; + background: #AAA; + + text-align: left; + margin-left: 0; + width: 50%; +} + +div.sphinxsidebar .badge { + border-bottom: none; +} + +div.sphinxsidebar .badge:hover { + border-bottom: none; +} + +/* To address an issue with donation coming after search */ +div.sphinxsidebar h3.donation { + margin-top: 10px; +} + +/* -- body styles ----------------------------------------------------------- */ + +a { + color: #004B6B; + text-decoration: underline; +} + +a:hover { + color: #6D4100; + text-decoration: underline; +} + +div.body h1, +div.body h2, +div.body h3, +div.body h4, +div.body h5, +div.body h6 { + font-family: Georgia, serif; + font-weight: normal; + margin: 30px 0px 10px 0px; + padding: 0; +} + +div.body h1 { margin-top: 0; padding-top: 0; font-size: 240%; } +div.body h2 { font-size: 180%; } +div.body h3 { font-size: 150%; } +div.body h4 { font-size: 130%; } 
+div.body h5 { font-size: 100%; } +div.body h6 { font-size: 100%; } + +a.headerlink { + color: #DDD; + padding: 0 4px; + text-decoration: none; +} + +a.headerlink:hover { + color: #444; + background: #EAEAEA; +} + +div.body p, div.body dd, div.body li { + line-height: 1.4em; +} + +div.admonition { + margin: 20px 0px; + padding: 10px 30px; + background-color: #EEE; + border: 1px solid #CCC; +} + +div.admonition tt.xref, div.admonition code.xref, div.admonition a tt { + background-color: #FBFBFB; + border-bottom: 1px solid #fafafa; +} + +div.admonition p.admonition-title { + font-family: Georgia, serif; + font-weight: normal; + font-size: 24px; + margin: 0 0 10px 0; + padding: 0; + line-height: 1; +} + +div.admonition p.last { + margin-bottom: 0; +} + +div.highlight { + background-color: #fff; +} + +dt:target, .highlight { + background: #FAF3E8; +} + +div.warning { + background-color: #FCC; + border: 1px solid #FAA; +} + +div.danger { + background-color: #FCC; + border: 1px solid #FAA; + -moz-box-shadow: 2px 2px 4px #D52C2C; + -webkit-box-shadow: 2px 2px 4px #D52C2C; + box-shadow: 2px 2px 4px #D52C2C; +} + +div.error { + background-color: #FCC; + border: 1px solid #FAA; + -moz-box-shadow: 2px 2px 4px #D52C2C; + -webkit-box-shadow: 2px 2px 4px #D52C2C; + box-shadow: 2px 2px 4px #D52C2C; +} + +div.caution { + background-color: #FCC; + border: 1px solid #FAA; +} + +div.attention { + background-color: #FCC; + border: 1px solid #FAA; +} + +div.important { + background-color: #EEE; + border: 1px solid #CCC; +} + +div.note { + background-color: #EEE; + border: 1px solid #CCC; +} + +div.tip { + background-color: #EEE; + border: 1px solid #CCC; +} + +div.hint { + background-color: #EEE; + border: 1px solid #CCC; +} + +div.seealso { + background-color: #EEE; + border: 1px solid #CCC; +} + +div.topic { + background-color: #EEE; +} + +p.admonition-title { + display: inline; +} + +p.admonition-title:after { + content: ":"; +} + +pre, tt, code { + font-family: 'Consolas', 'Menlo', 'DejaVu Sans Mono', 'Bitstream Vera Sans Mono', monospace; + font-size: 0.9em; +} + +.hll { + background-color: #FFC; + margin: 0 -12px; + padding: 0 12px; + display: block; +} + +img.screenshot { +} + +tt.descname, tt.descclassname, code.descname, code.descclassname { + font-size: 0.95em; +} + +tt.descname, code.descname { + padding-right: 0.08em; +} + +img.screenshot { + -moz-box-shadow: 2px 2px 4px #EEE; + -webkit-box-shadow: 2px 2px 4px #EEE; + box-shadow: 2px 2px 4px #EEE; +} + +table.docutils { + border: 1px solid #888; + -moz-box-shadow: 2px 2px 4px #EEE; + -webkit-box-shadow: 2px 2px 4px #EEE; + box-shadow: 2px 2px 4px #EEE; +} + +table.docutils td, table.docutils th { + border: 1px solid #888; + padding: 0.25em 0.7em; +} + +table.field-list, table.footnote { + border: none; + -moz-box-shadow: none; + -webkit-box-shadow: none; + box-shadow: none; +} + +table.footnote { + margin: 15px 0; + width: 100%; + border: 1px solid #EEE; + background: #FDFDFD; + font-size: 0.9em; +} + +table.footnote + table.footnote { + margin-top: -15px; + border-top: none; +} + +table.field-list th { + padding: 0 0.8em 0 0; +} + +table.field-list td { + padding: 0; +} + +table.field-list p { + margin-bottom: 0.8em; +} + +/* Cloned from + * https://github.com/sphinx-doc/sphinx/commit/ef60dbfce09286b20b7385333d63a60321784e68 + */ +.field-name { + -moz-hyphens: manual; + -ms-hyphens: manual; + -webkit-hyphens: manual; + hyphens: manual; +} + +table.footnote td.label { + width: .1px; + padding: 0.3em 0 0.3em 0.5em; +} + +table.footnote td { + 
padding: 0.3em 0.5em; +} + +dl { + margin: 0; + padding: 0; +} + +dl dd { + margin-left: 30px; +} + +blockquote { + margin: 0 0 0 30px; + padding: 0; +} + +ul, ol { + /* Matches the 30px from the narrow-screen "li > ul" selector below */ + margin: 10px 0 10px 30px; + padding: 0; +} + +pre { + background: #EEE; + padding: 7px 30px; + margin: 15px 0px; + line-height: 1.3em; +} + +div.viewcode-block:target { + background: #ffd; +} + +dl pre, blockquote pre, li pre { + margin-left: 0; + padding-left: 30px; +} + +tt, code { + background-color: #ecf0f3; + color: #222; + /* padding: 1px 2px; */ +} + +tt.xref, code.xref, a tt { + background-color: #FBFBFB; + border-bottom: 1px solid #fff; +} + +a.reference { + text-decoration: none; + border-bottom: 1px dotted #004B6B; +} + +/* Don't put an underline on images */ +a.image-reference, a.image-reference:hover { + border-bottom: none; +} + +a.reference:hover { + border-bottom: 1px solid #6D4100; +} + +a.footnote-reference { + text-decoration: none; + font-size: 0.7em; + vertical-align: top; + border-bottom: 1px dotted #004B6B; +} + +a.footnote-reference:hover { + border-bottom: 1px solid #6D4100; +} + +a:hover tt, a:hover code { + background: #EEE; +} + + +@media screen and (max-width: 870px) { + + div.sphinxsidebar { + display: none; + } + + div.document { + width: 100%; + + } + + div.documentwrapper { + margin-left: 0; + margin-top: 0; + margin-right: 0; + margin-bottom: 0; + } + + div.bodywrapper { + margin-top: 0; + margin-right: 0; + margin-bottom: 0; + margin-left: 0; + } + + ul { + margin-left: 0; + } + + li > ul { + /* Matches the 30px from the "ul, ol" selector above */ + margin-left: 30px; + } + + .document { + width: auto; + } + + .footer { + width: auto; + } + + .bodywrapper { + margin: 0; + } + + .footer { + width: auto; + } + + .github { + display: none; + } + + + +} + + + +@media screen and (max-width: 875px) { + + body { + margin: 0; + padding: 20px 30px; + } + + div.documentwrapper { + float: none; + background: #fff; + } + + div.sphinxsidebar { + display: block; + float: none; + width: 102.5%; + margin: 50px -30px -20px -30px; + padding: 10px 20px; + background: #333; + color: #FFF; + } + + div.sphinxsidebar h3, div.sphinxsidebar h4, div.sphinxsidebar p, + div.sphinxsidebar h3 a { + color: #fff; + } + + div.sphinxsidebar a { + color: #AAA; + } + + div.sphinxsidebar p.logo { + display: none; + } + + div.document { + width: 100%; + margin: 0; + } + + div.footer { + display: none; + } + + div.bodywrapper { + margin: 0; + } + + div.body { + min-height: 0; + padding: 0; + } + + .rtd_doc_footer { + display: none; + } + + .document { + width: auto; + } + + .footer { + width: auto; + } + + .footer { + width: auto; + } + + .github { + display: none; + } +} + + +/* misc. */ + +.revsys-inline { + display: none!important; +} + +/* Make nested-list/multi-paragraph items look better in Releases changelog + * pages. Without this, docutils' magical list fuckery causes inconsistent + * formatting between different release sub-lists. 
+ */ +div#changelog > div.section > ul > li > p:only-child { + margin-bottom: 0; +} + +/* Hide fugly table cell borders in ..bibliography:: directive output */ +table.docutils.citation, table.docutils.citation td, table.docutils.citation th { + border: none; + /* Below needed in some edge cases; if not applied, bottom shadows appear */ + -moz-box-shadow: none; + -webkit-box-shadow: none; + box-shadow: none; +} + + +/* relbar */ + +.related { + line-height: 30px; + width: 100%; + font-size: 0.9rem; +} + +.related.top { + border-bottom: 1px solid #EEE; + margin-bottom: 20px; +} + +.related.bottom { + border-top: 1px solid #EEE; +} + +.related ul { + padding: 0; + margin: 0; + list-style: none; +} + +.related li { + display: inline; +} + +nav#rellinks { + float: right; +} + +nav#rellinks li+li:before { + content: "|"; +} + +nav#breadcrumbs li+li:before { + content: "\00BB"; +} + +/* Hide certain items when printing */ +@media print { + div.related { + display: none; + } +} \ No newline at end of file diff --git a/docs/build/html/_static/basic.css b/docs/build/html/_static/basic.css index 104f076..0807176 100644 --- a/docs/build/html/_static/basic.css +++ b/docs/build/html/_static/basic.css @@ -4,7 +4,7 @@ * * Sphinx stylesheet -- basic theme. * - * :copyright: Copyright 2007-2018 by the Sphinx team, see AUTHORS. + * :copyright: Copyright 2007-2019 by the Sphinx team, see AUTHORS. * :license: BSD, see LICENSE for details. * */ diff --git a/docs/build/html/_static/css/badge_only.css b/docs/build/html/_static/css/badge_only.css index 323730a..3c33cef 100644 --- a/docs/build/html/_static/css/badge_only.css +++ b/docs/build/html/_static/css/badge_only.css @@ -1 +1 @@ -.fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:before,.clearfix:after{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-weight:normal;font-style:normal;src:url("../fonts/fontawesome-webfont.eot");src:url("../fonts/fontawesome-webfont.eot?#iefix") format("embedded-opentype"),url("../fonts/fontawesome-webfont.woff") format("woff"),url("../fonts/fontawesome-webfont.ttf") format("truetype"),url("../fonts/fontawesome-webfont.svg#FontAwesome") format("svg")}.fa:before{display:inline-block;font-family:FontAwesome;font-style:normal;font-weight:normal;line-height:1;text-decoration:inherit}a .fa{display:inline-block;text-decoration:inherit}li .fa{display:inline-block}li .fa-large:before,li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-0.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before,ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before{content:""}.icon-book:before{content:""}.fa-caret-down:before{content:""}.icon-caret-down:before{content:""}.fa-caret-up:before{content:""}.icon-caret-up:before{content:""}.fa-caret-left:before{content:""}.icon-caret-left:before{content:""}.fa-caret-right:before{content:""}.icon-caret-right:before{content:""}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;z-index:400}.rst-versions a{color:#2980B9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27AE60;*zoom:1}.rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after{display:table;content:""}.rst-versions 
.rst-current-version:after{clear:both}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book{float:left}.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#E74C3C;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#F1C40F;color:#000}.rst-versions.shift-up{height:auto;max-height:100%}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:gray;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:solid 1px #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px}.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge .fa-book{float:none}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book{float:left}.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge .rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width: 768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} +.fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:before,.clearfix:after{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-weight:normal;font-style:normal;src:url("../fonts/fontawesome-webfont.eot");src:url("../fonts/fontawesome-webfont.eot?#iefix") format("embedded-opentype"),url("../fonts/fontawesome-webfont.woff") format("woff"),url("../fonts/fontawesome-webfont.ttf") format("truetype"),url("../fonts/fontawesome-webfont.svg#FontAwesome") format("svg")}.fa:before{display:inline-block;font-family:FontAwesome;font-style:normal;font-weight:normal;line-height:1;text-decoration:inherit}a .fa{display:inline-block;text-decoration:inherit}li .fa{display:inline-block}li .fa-large:before,li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-0.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before,ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before{content:""}.icon-book:before{content:""}.fa-caret-down:before{content:""}.icon-caret-down:before{content:""}.fa-caret-up:before{content:""}.icon-caret-up:before{content:""}.fa-caret-left:before{content:""}.icon-caret-left:before{content:""}.fa-caret-right:before{content:""}.icon-caret-right:before{content:""}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;z-index:400}.rst-versions a{color:#2980B9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27AE60;*zoom:1}.rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after{display:table;content:""}.rst-versions .rst-current-version:after{clear:both}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book{float:left}.rst-versions .rst-current-version .icon-book{float:left}.rst-versions 
.rst-current-version.rst-out-of-date{background-color:#E74C3C;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#F1C40F;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:gray;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:solid 1px #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge .fa-book{float:none}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book{float:left}.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge .rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width: 768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} diff --git a/docs/build/html/_static/css/theme.css b/docs/build/html/_static/css/theme.css index b19dbfe..aed8cef 100644 --- a/docs/build/html/_static/css/theme.css +++ b/docs/build/html/_static/css/theme.css @@ -1,6 +1,6 @@ -/* sphinx_rtd_theme version 0.4.2 | MIT license */ -/* Built 20181005 13:10 */ -*{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}article,aside,details,figcaption,figure,footer,header,hgroup,nav,section{display:block}audio,canvas,video{display:inline-block;*display:inline;*zoom:1}audio:not([controls]){display:none}[hidden]{display:none}*{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}html{font-size:100%;-webkit-text-size-adjust:100%;-ms-text-size-adjust:100%}body{margin:0}a:hover,a:active{outline:0}abbr[title]{border-bottom:1px dotted}b,strong{font-weight:bold}blockquote{margin:0}dfn{font-style:italic}ins{background:#ff9;color:#000;text-decoration:none}mark{background:#ff0;color:#000;font-style:italic;font-weight:bold}pre,code,.rst-content tt,.rst-content code,kbd,samp{font-family:monospace,serif;_font-family:"courier 
new",monospace;font-size:1em}pre{white-space:pre}q{quotes:none}q:before,q:after{content:"";content:none}small{font-size:85%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sup{top:-0.5em}sub{bottom:-0.25em}ul,ol,dl{margin:0;padding:0;list-style:none;list-style-image:none}li{list-style:none}dd{margin:0}img{border:0;-ms-interpolation-mode:bicubic;vertical-align:middle;max-width:100%}svg:not(:root){overflow:hidden}figure{margin:0}form{margin:0}fieldset{border:0;margin:0;padding:0}label{cursor:pointer}legend{border:0;*margin-left:-7px;padding:0;white-space:normal}button,input,select,textarea{font-size:100%;margin:0;vertical-align:baseline;*vertical-align:middle}button,input{line-height:normal}button,input[type="button"],input[type="reset"],input[type="submit"]{cursor:pointer;-webkit-appearance:button;*overflow:visible}button[disabled],input[disabled]{cursor:default}input[type="checkbox"],input[type="radio"]{box-sizing:border-box;padding:0;*width:13px;*height:13px}input[type="search"]{-webkit-appearance:textfield;-moz-box-sizing:content-box;-webkit-box-sizing:content-box;box-sizing:content-box}input[type="search"]::-webkit-search-decoration,input[type="search"]::-webkit-search-cancel-button{-webkit-appearance:none}button::-moz-focus-inner,input::-moz-focus-inner{border:0;padding:0}textarea{overflow:auto;vertical-align:top;resize:vertical}table{border-collapse:collapse;border-spacing:0}td{vertical-align:top}.chromeframe{margin:.2em 0;background:#ccc;color:#000;padding:.2em 0}.ir{display:block;border:0;text-indent:-999em;overflow:hidden;background-color:transparent;background-repeat:no-repeat;text-align:left;direction:ltr;*line-height:0}.ir br{display:none}.hidden{display:none !important;visibility:hidden}.visuallyhidden{border:0;clip:rect(0 0 0 0);height:1px;margin:-1px;overflow:hidden;padding:0;position:absolute;width:1px}.visuallyhidden.focusable:active,.visuallyhidden.focusable:focus{clip:auto;height:auto;margin:0;overflow:visible;position:static;width:auto}.invisible{visibility:hidden}.relative{position:relative}big,small{font-size:100%}@media print{html,body,section{background:none !important}*{box-shadow:none !important;text-shadow:none !important;filter:none !important;-ms-filter:none !important}a,a:visited{text-decoration:underline}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:.5cm}p,h2,.rst-content .toctree-wrapper p.caption,h3{orphans:3;widows:3}h2,.rst-content .toctree-wrapper p.caption,h3{page-break-after:avoid}}.fa:before,.wy-menu-vertical li span.toctree-expand:before,.wy-menu-vertical li.on a span.toctree-expand:before,.wy-menu-vertical li.current>a span.toctree-expand:before,.rst-content .admonition-title:before,.rst-content h1 .headerlink:before,.rst-content h2 .headerlink:before,.rst-content h3 .headerlink:before,.rst-content h4 .headerlink:before,.rst-content h5 .headerlink:before,.rst-content h6 .headerlink:before,.rst-content dl dt .headerlink:before,.rst-content p.caption .headerlink:before,.rst-content table>caption .headerlink:before,.rst-content tt.download span:first-child:before,.rst-content code.download span:first-child:before,.icon:before,.wy-dropdown .caret:before,.wy-inline-validate.wy-inline-validate-success .wy-input-context:before,.wy-inline-validate.wy-inline-validate-danger .wy-input-context:before,.wy-inline-validate.wy-inline-validate-warning 
.wy-input-context:before,.wy-inline-validate.wy-inline-validate-info .wy-input-context:before,.wy-alert,.rst-content .note,.rst-content .attention,.rst-content .caution,.rst-content .danger,.rst-content .error,.rst-content .hint,.rst-content .important,.rst-content .tip,.rst-content .warning,.rst-content .seealso,.rst-content .admonition-todo,.rst-content .admonition,.btn,input[type="text"],input[type="password"],input[type="email"],input[type="url"],input[type="date"],input[type="month"],input[type="time"],input[type="datetime"],input[type="datetime-local"],input[type="week"],input[type="number"],input[type="search"],input[type="tel"],input[type="color"],select,textarea,.wy-menu-vertical li.on a,.wy-menu-vertical li.current>a,.wy-side-nav-search>a,.wy-side-nav-search .wy-dropdown>a,.wy-nav-top a{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:before,.clearfix:after{display:table;content:""}.clearfix:after{clear:both}/*! +/* sphinx_rtd_theme version 0.4.3 | MIT license */ +/* Built 20190212 16:02 */ +*{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}article,aside,details,figcaption,figure,footer,header,hgroup,nav,section{display:block}audio,canvas,video{display:inline-block;*display:inline;*zoom:1}audio:not([controls]){display:none}[hidden]{display:none}*{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}html{font-size:100%;-webkit-text-size-adjust:100%;-ms-text-size-adjust:100%}body{margin:0}a:hover,a:active{outline:0}abbr[title]{border-bottom:1px dotted}b,strong{font-weight:bold}blockquote{margin:0}dfn{font-style:italic}ins{background:#ff9;color:#000;text-decoration:none}mark{background:#ff0;color:#000;font-style:italic;font-weight:bold}pre,code,.rst-content tt,.rst-content code,kbd,samp{font-family:monospace,serif;_font-family:"courier new",monospace;font-size:1em}pre{white-space:pre}q{quotes:none}q:before,q:after{content:"";content:none}small{font-size:85%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sup{top:-0.5em}sub{bottom:-0.25em}ul,ol,dl{margin:0;padding:0;list-style:none;list-style-image:none}li{list-style:none}dd{margin:0}img{border:0;-ms-interpolation-mode:bicubic;vertical-align:middle;max-width:100%}svg:not(:root){overflow:hidden}figure{margin:0}form{margin:0}fieldset{border:0;margin:0;padding:0}label{cursor:pointer}legend{border:0;*margin-left:-7px;padding:0;white-space:normal}button,input,select,textarea{font-size:100%;margin:0;vertical-align:baseline;*vertical-align:middle}button,input{line-height:normal}button,input[type="button"],input[type="reset"],input[type="submit"]{cursor:pointer;-webkit-appearance:button;*overflow:visible}button[disabled],input[disabled]{cursor:default}input[type="checkbox"],input[type="radio"]{box-sizing:border-box;padding:0;*width:13px;*height:13px}input[type="search"]{-webkit-appearance:textfield;-moz-box-sizing:content-box;-webkit-box-sizing:content-box;box-sizing:content-box}input[type="search"]::-webkit-search-decoration,input[type="search"]::-webkit-search-cancel-button{-webkit-appearance:none}button::-moz-focus-inner,input::-moz-focus-inner{border:0;padding:0}textarea{overflow:auto;vertical-align:top;resize:vertical}table{border-collapse:collapse;border-spacing:0}td{vertical-align:top}.chromeframe{margin:.2em 0;background:#ccc;color:#000;padding:.2em 0}.ir{display:block;border:0;text-indent:-999em;overflow:hidden;background-color:transparent;background-repeat:no-repeat;text-align:left;direction:ltr;*line-height:0}.ir 
br{display:none}.hidden{display:none !important;visibility:hidden}.visuallyhidden{border:0;clip:rect(0 0 0 0);height:1px;margin:-1px;overflow:hidden;padding:0;position:absolute;width:1px}.visuallyhidden.focusable:active,.visuallyhidden.focusable:focus{clip:auto;height:auto;margin:0;overflow:visible;position:static;width:auto}.invisible{visibility:hidden}.relative{position:relative}big,small{font-size:100%}@media print{html,body,section{background:none !important}*{box-shadow:none !important;text-shadow:none !important;filter:none !important;-ms-filter:none !important}a,a:visited{text-decoration:underline}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:.5cm}p,h2,.rst-content .toctree-wrapper p.caption,h3{orphans:3;widows:3}h2,.rst-content .toctree-wrapper p.caption,h3{page-break-after:avoid}}.fa:before,.wy-menu-vertical li span.toctree-expand:before,.wy-menu-vertical li.on a span.toctree-expand:before,.wy-menu-vertical li.current>a span.toctree-expand:before,.rst-content .admonition-title:before,.rst-content h1 .headerlink:before,.rst-content h2 .headerlink:before,.rst-content h3 .headerlink:before,.rst-content h4 .headerlink:before,.rst-content h5 .headerlink:before,.rst-content h6 .headerlink:before,.rst-content dl dt .headerlink:before,.rst-content p.caption .headerlink:before,.rst-content table>caption .headerlink:before,.rst-content .code-block-caption .headerlink:before,.rst-content tt.download span:first-child:before,.rst-content code.download span:first-child:before,.icon:before,.wy-dropdown .caret:before,.wy-inline-validate.wy-inline-validate-success .wy-input-context:before,.wy-inline-validate.wy-inline-validate-danger .wy-input-context:before,.wy-inline-validate.wy-inline-validate-warning .wy-input-context:before,.wy-inline-validate.wy-inline-validate-info .wy-input-context:before,.wy-alert,.rst-content .note,.rst-content .attention,.rst-content .caution,.rst-content .danger,.rst-content .error,.rst-content .hint,.rst-content .important,.rst-content .tip,.rst-content .warning,.rst-content .seealso,.rst-content .admonition-todo,.rst-content .admonition,.btn,input[type="text"],input[type="password"],input[type="email"],input[type="url"],input[type="date"],input[type="month"],input[type="time"],input[type="datetime"],input[type="datetime-local"],input[type="week"],input[type="number"],input[type="search"],input[type="tel"],input[type="color"],select,textarea,.wy-menu-vertical li.on a,.wy-menu-vertical li.current>a,.wy-side-nav-search>a,.wy-side-nav-search .wy-dropdown>a,.wy-nav-top a{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:before,.clearfix:after{display:table;content:""}.clearfix:after{clear:both}/*! 
* Font Awesome 4.7.0 by @davegandy - http://fontawesome.io - @fontawesome * License - http://fontawesome.io/license (Font: SIL OFL 1.1, CSS: MIT License) - */@font-face{font-family:'FontAwesome';src:url("../fonts/fontawesome-webfont.eot?v=4.7.0");src:url("../fonts/fontawesome-webfont.eot?#iefix&v=4.7.0") format("embedded-opentype"),url("../fonts/fontawesome-webfont.woff2?v=4.7.0") format("woff2"),url("../fonts/fontawesome-webfont.woff?v=4.7.0") format("woff"),url("../fonts/fontawesome-webfont.ttf?v=4.7.0") format("truetype"),url("../fonts/fontawesome-webfont.svg?v=4.7.0#fontawesomeregular") format("svg");font-weight:normal;font-style:normal}.fa,.wy-menu-vertical li span.toctree-expand,.wy-menu-vertical li.on a span.toctree-expand,.wy-menu-vertical li.current>a span.toctree-expand,.rst-content .admonition-title,.rst-content h1 .headerlink,.rst-content h2 .headerlink,.rst-content h3 .headerlink,.rst-content h4 .headerlink,.rst-content h5 .headerlink,.rst-content h6 .headerlink,.rst-content dl dt .headerlink,.rst-content p.caption .headerlink,.rst-content table>caption .headerlink,.rst-content tt.download span:first-child,.rst-content code.download span:first-child,.icon{display:inline-block;font:normal normal normal 14px/1 FontAwesome;font-size:inherit;text-rendering:auto;-webkit-font-smoothing:antialiased;-moz-osx-font-smoothing:grayscale}.fa-lg{font-size:1.3333333333em;line-height:.75em;vertical-align:-15%}.fa-2x{font-size:2em}.fa-3x{font-size:3em}.fa-4x{font-size:4em}.fa-5x{font-size:5em}.fa-fw{width:1.2857142857em;text-align:center}.fa-ul{padding-left:0;margin-left:2.1428571429em;list-style-type:none}.fa-ul>li{position:relative}.fa-li{position:absolute;left:-2.1428571429em;width:2.1428571429em;top:.1428571429em;text-align:center}.fa-li.fa-lg{left:-1.8571428571em}.fa-border{padding:.2em .25em .15em;border:solid 0.08em #eee;border-radius:.1em}.fa-pull-left{float:left}.fa-pull-right{float:right}.fa.fa-pull-left,.wy-menu-vertical li span.fa-pull-left.toctree-expand,.wy-menu-vertical li.on a span.fa-pull-left.toctree-expand,.wy-menu-vertical li.current>a span.fa-pull-left.toctree-expand,.rst-content .fa-pull-left.admonition-title,.rst-content h1 .fa-pull-left.headerlink,.rst-content h2 .fa-pull-left.headerlink,.rst-content h3 .fa-pull-left.headerlink,.rst-content h4 .fa-pull-left.headerlink,.rst-content h5 .fa-pull-left.headerlink,.rst-content h6 .fa-pull-left.headerlink,.rst-content dl dt .fa-pull-left.headerlink,.rst-content p.caption .fa-pull-left.headerlink,.rst-content table>caption .fa-pull-left.headerlink,.rst-content tt.download span.fa-pull-left:first-child,.rst-content code.download span.fa-pull-left:first-child,.fa-pull-left.icon{margin-right:.3em}.fa.fa-pull-right,.wy-menu-vertical li span.fa-pull-right.toctree-expand,.wy-menu-vertical li.on a span.fa-pull-right.toctree-expand,.wy-menu-vertical li.current>a span.fa-pull-right.toctree-expand,.rst-content .fa-pull-right.admonition-title,.rst-content h1 .fa-pull-right.headerlink,.rst-content h2 .fa-pull-right.headerlink,.rst-content h3 .fa-pull-right.headerlink,.rst-content h4 .fa-pull-right.headerlink,.rst-content h5 .fa-pull-right.headerlink,.rst-content h6 .fa-pull-right.headerlink,.rst-content dl dt .fa-pull-right.headerlink,.rst-content p.caption .fa-pull-right.headerlink,.rst-content table>caption .fa-pull-right.headerlink,.rst-content tt.download span.fa-pull-right:first-child,.rst-content code.download 
span.fa-pull-right:first-child,.fa-pull-right.icon{margin-left:.3em}.pull-right{float:right}.pull-left{float:left}.fa.pull-left,.wy-menu-vertical li span.pull-left.toctree-expand,.wy-menu-vertical li.on a span.pull-left.toctree-expand,.wy-menu-vertical li.current>a span.pull-left.toctree-expand,.rst-content .pull-left.admonition-title,.rst-content h1 .pull-left.headerlink,.rst-content h2 .pull-left.headerlink,.rst-content h3 .pull-left.headerlink,.rst-content h4 .pull-left.headerlink,.rst-content h5 .pull-left.headerlink,.rst-content h6 .pull-left.headerlink,.rst-content dl dt .pull-left.headerlink,.rst-content p.caption .pull-left.headerlink,.rst-content table>caption .pull-left.headerlink,.rst-content tt.download span.pull-left:first-child,.rst-content code.download span.pull-left:first-child,.pull-left.icon{margin-right:.3em}.fa.pull-right,.wy-menu-vertical li span.pull-right.toctree-expand,.wy-menu-vertical li.on a span.pull-right.toctree-expand,.wy-menu-vertical li.current>a span.pull-right.toctree-expand,.rst-content .pull-right.admonition-title,.rst-content h1 .pull-right.headerlink,.rst-content h2 .pull-right.headerlink,.rst-content h3 .pull-right.headerlink,.rst-content h4 .pull-right.headerlink,.rst-content h5 .pull-right.headerlink,.rst-content h6 .pull-right.headerlink,.rst-content dl dt .pull-right.headerlink,.rst-content p.caption .pull-right.headerlink,.rst-content table>caption .pull-right.headerlink,.rst-content tt.download span.pull-right:first-child,.rst-content code.download span.pull-right:first-child,.pull-right.icon{margin-left:.3em}.fa-spin{-webkit-animation:fa-spin 2s infinite linear;animation:fa-spin 2s infinite linear}.fa-pulse{-webkit-animation:fa-spin 1s infinite steps(8);animation:fa-spin 1s infinite steps(8)}@-webkit-keyframes fa-spin{0%{-webkit-transform:rotate(0deg);transform:rotate(0deg)}100%{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes fa-spin{0%{-webkit-transform:rotate(0deg);transform:rotate(0deg)}100%{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.fa-rotate-90{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=1)";-webkit-transform:rotate(90deg);-ms-transform:rotate(90deg);transform:rotate(90deg)}.fa-rotate-180{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=2)";-webkit-transform:rotate(180deg);-ms-transform:rotate(180deg);transform:rotate(180deg)}.fa-rotate-270{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=3)";-webkit-transform:rotate(270deg);-ms-transform:rotate(270deg);transform:rotate(270deg)}.fa-flip-horizontal{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=0, mirror=1)";-webkit-transform:scale(-1, 1);-ms-transform:scale(-1, 1);transform:scale(-1, 1)}.fa-flip-vertical{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=2, mirror=1)";-webkit-transform:scale(1, -1);-ms-transform:scale(1, -1);transform:scale(1, -1)}:root .fa-rotate-90,:root .fa-rotate-180,:root .fa-rotate-270,:root .fa-flip-horizontal,:root 
.fa-flip-vertical{filter:none}.fa-stack{position:relative;display:inline-block;width:2em;height:2em;line-height:2em;vertical-align:middle}.fa-stack-1x,.fa-stack-2x{position:absolute;left:0;width:100%;text-align:center}.fa-stack-1x{line-height:inherit}.fa-stack-2x{font-size:2em}.fa-inverse{color:#fff}.fa-glass:before{content:""}.fa-music:before{content:""}.fa-search:before,.icon-search:before{content:""}.fa-envelope-o:before{content:""}.fa-heart:before{content:""}.fa-star:before{content:""}.fa-star-o:before{content:""}.fa-user:before{content:""}.fa-film:before{content:""}.fa-th-large:before{content:""}.fa-th:before{content:""}.fa-th-list:before{content:""}.fa-check:before{content:""}.fa-remove:before,.fa-close:before,.fa-times:before{content:""}.fa-search-plus:before{content:""}.fa-search-minus:before{content:""}.fa-power-off:before{content:""}.fa-signal:before{content:""}.fa-gear:before,.fa-cog:before{content:""}.fa-trash-o:before{content:""}.fa-home:before,.icon-home:before{content:""}.fa-file-o:before{content:""}.fa-clock-o:before{content:""}.fa-road:before{content:""}.fa-download:before,.rst-content tt.download span:first-child:before,.rst-content code.download span:first-child:before{content:""}.fa-arrow-circle-o-down:before{content:""}.fa-arrow-circle-o-up:before{content:""}.fa-inbox:before{content:""}.fa-play-circle-o:before{content:""}.fa-rotate-right:before,.fa-repeat:before{content:""}.fa-refresh:before{content:""}.fa-list-alt:before{content:""}.fa-lock:before{content:""}.fa-flag:before{content:""}.fa-headphones:before{content:""}.fa-volume-off:before{content:""}.fa-volume-down:before{content:""}.fa-volume-up:before{content:""}.fa-qrcode:before{content:""}.fa-barcode:before{content:""}.fa-tag:before{content:""}.fa-tags:before{content:""}.fa-book:before,.icon-book:before{content:""}.fa-bookmark:before{content:""}.fa-print:before{content:""}.fa-camera:before{content:""}.fa-font:before{content:""}.fa-bold:before{content:""}.fa-italic:before{content:""}.fa-text-height:before{content:""}.fa-text-width:before{content:""}.fa-align-left:before{content:""}.fa-align-center:before{content:""}.fa-align-right:before{content:""}.fa-align-justify:before{content:""}.fa-list:before{content:""}.fa-dedent:before,.fa-outdent:before{content:""}.fa-indent:before{content:""}.fa-video-camera:before{content:""}.fa-photo:before,.fa-image:before,.fa-picture-o:before{content:""}.fa-pencil:before{content:""}.fa-map-marker:before{content:""}.fa-adjust:before{content:""}.fa-tint:before{content:""}.fa-edit:before,.fa-pencil-square-o:before{content:""}.fa-share-square-o:before{content:""}.fa-check-square-o:before{content:""}.fa-arrows:before{content:""}.fa-step-backward:before{content:""}.fa-fast-backward:before{content:""}.fa-backward:before{content:""}.fa-play:before{content:""}.fa-pause:before{content:""}.fa-stop:before{content:""}.fa-forward:before{content:""}.fa-fast-forward:before{content:""}.fa-step-forward:before{content:""}.fa-eject:before{content:""}.fa-chevron-left:before{content:""}.fa-chevron-right:before{content:""}.fa-plus-circle:before{content:""}.fa-minus-circle:before{content:""}.fa-times-circle:before,.wy-inline-validate.wy-inline-validate-danger .wy-input-context:before{content:""}.fa-check-circle:before,.wy-inline-validate.wy-inline-validate-success 
.wy-input-context:before{content:""}.fa-question-circle:before{content:""}.fa-info-circle:before{content:""}.fa-crosshairs:before{content:""}.fa-times-circle-o:before{content:""}.fa-check-circle-o:before{content:""}.fa-ban:before{content:""}.fa-arrow-left:before{content:""}.fa-arrow-right:before{content:""}.fa-arrow-up:before{content:""}.fa-arrow-down:before{content:""}.fa-mail-forward:before,.fa-share:before{content:""}.fa-expand:before{content:""}.fa-compress:before{content:""}.fa-plus:before{content:""}.fa-minus:before{content:""}.fa-asterisk:before{content:""}.fa-exclamation-circle:before,.wy-inline-validate.wy-inline-validate-warning .wy-input-context:before,.wy-inline-validate.wy-inline-validate-info .wy-input-context:before,.rst-content .admonition-title:before{content:""}.fa-gift:before{content:""}.fa-leaf:before{content:""}.fa-fire:before,.icon-fire:before{content:""}.fa-eye:before{content:""}.fa-eye-slash:before{content:""}.fa-warning:before,.fa-exclamation-triangle:before{content:""}.fa-plane:before{content:""}.fa-calendar:before{content:""}.fa-random:before{content:""}.fa-comment:before{content:""}.fa-magnet:before{content:""}.fa-chevron-up:before{content:""}.fa-chevron-down:before{content:""}.fa-retweet:before{content:""}.fa-shopping-cart:before{content:""}.fa-folder:before{content:""}.fa-folder-open:before{content:""}.fa-arrows-v:before{content:""}.fa-arrows-h:before{content:""}.fa-bar-chart-o:before,.fa-bar-chart:before{content:""}.fa-twitter-square:before{content:""}.fa-facebook-square:before{content:""}.fa-camera-retro:before{content:""}.fa-key:before{content:""}.fa-gears:before,.fa-cogs:before{content:""}.fa-comments:before{content:""}.fa-thumbs-o-up:before{content:""}.fa-thumbs-o-down:before{content:""}.fa-star-half:before{content:""}.fa-heart-o:before{content:""}.fa-sign-out:before{content:""}.fa-linkedin-square:before{content:""}.fa-thumb-tack:before{content:""}.fa-external-link:before{content:""}.fa-sign-in:before{content:""}.fa-trophy:before{content:""}.fa-github-square:before{content:""}.fa-upload:before{content:""}.fa-lemon-o:before{content:""}.fa-phone:before{content:""}.fa-square-o:before{content:""}.fa-bookmark-o:before{content:""}.fa-phone-square:before{content:""}.fa-twitter:before{content:""}.fa-facebook-f:before,.fa-facebook:before{content:""}.fa-github:before,.icon-github:before{content:""}.fa-unlock:before{content:""}.fa-credit-card:before{content:""}.fa-feed:before,.fa-rss:before{content:""}.fa-hdd-o:before{content:""}.fa-bullhorn:before{content:""}.fa-bell:before{content:""}.fa-certificate:before{content:""}.fa-hand-o-right:before{content:""}.fa-hand-o-left:before{content:""}.fa-hand-o-up:before{content:""}.fa-hand-o-down:before{content:""}.fa-arrow-circle-left:before,.icon-circle-arrow-left:before{content:""}.fa-arrow-circle-right:before,.icon-circle-arrow-right:before{content:""}.fa-arrow-circle-up:before{content:""}.fa-arrow-circle-down:before{content:""}.fa-globe:before{content:""}.fa-wrench:before{content:""}.fa-tasks:before{content:""}.fa-filter:before{content:""}.fa-briefcase:before{content:""}.fa-arrows-alt:before{content:""}.fa-group:before,.fa-users:before{content:""}.fa-chain:before,.fa-link:before,.icon-link:before{content:""}.fa-cloud:before{content:""}.fa-flask:before{content:""}.fa-cut:before,.fa-scissors:before{content:""}.fa-copy:before,.fa-files-o:before{content:""}.fa-paperclip:before{content:""}.fa-save:before,.fa-floppy-o:before{content:""}.fa
-square:before{content:""}.fa-navicon:before,.fa-reorder:before,.fa-bars:before{content:""}.fa-list-ul:before{content:""}.fa-list-ol:before{content:""}.fa-strikethrough:before{content:""}.fa-underline:before{content:""}.fa-table:before{content:""}.fa-magic:before{content:""}.fa-truck:before{content:""}.fa-pinterest:before{content:""}.fa-pinterest-square:before{content:""}.fa-google-plus-square:before{content:""}.fa-google-plus:before{content:""}.fa-money:before{content:""}.fa-caret-down:before,.wy-dropdown .caret:before,.icon-caret-down:before{content:""}.fa-caret-up:before{content:""}.fa-caret-left:before{content:""}.fa-caret-right:before{content:""}.fa-columns:before{content:""}.fa-unsorted:before,.fa-sort:before{content:""}.fa-sort-down:before,.fa-sort-desc:before{content:""}.fa-sort-up:before,.fa-sort-asc:before{content:""}.fa-envelope:before{content:""}.fa-linkedin:before{content:""}.fa-rotate-left:before,.fa-undo:before{content:""}.fa-legal:before,.fa-gavel:before{content:""}.fa-dashboard:before,.fa-tachometer:before{content:""}.fa-comment-o:before{content:""}.fa-comments-o:before{content:""}.fa-flash:before,.fa-bolt:before{content:""}.fa-sitemap:before{content:""}.fa-umbrella:before{content:""}.fa-paste:before,.fa-clipboard:before{content:""}.fa-lightbulb-o:before{content:""}.fa-exchange:before{content:""}.fa-cloud-download:before{content:""}.fa-cloud-upload:before{content:""}.fa-user-md:before{content:""}.fa-stethoscope:before{content:""}.fa-suitcase:before{content:""}.fa-bell-o:before{content:""}.fa-coffee:before{content:""}.fa-cutlery:before{content:""}.fa-file-text-o:before{content:""}.fa-building-o:before{content:""}.fa-hospital-o:before{content:""}.fa-ambulance:before{content:""}.fa-medkit:before{content:""}.fa-fighter-jet:before{content:""}.fa-beer:before{content:""}.fa-h-square:before{content:""}.fa-plus-square:before{content:""}.fa-angle-double-left:before{content:""}.fa-angle-double-right:before{content:""}.fa-angle-double-up:before{content:""}.fa-angle-double-down:before{content:""}.fa-angle-left:before{content:""}.fa-angle-right:before{content:""}.fa-angle-up:before{content:""}.fa-angle-down:before{content:""}.fa-desktop:before{content:""}.fa-laptop:before{content:""}.fa-tablet:before{content:""}.fa-mobile-phone:before,.fa-mobile:before{content:""}.fa-circle-o:before{content:""}.fa-quote-left:before{content:""}.fa-quote-right:before{content:""}.fa-spinner:before{content:""}.fa-circle:before{content:""}.fa-mail-reply:before,.fa-reply:before{content:""}.fa-github-alt:before{content:""}.fa-folder-o:before{content:""}.fa-folder-open-o:before{content:""}.fa-smile-o:before{content:""}.fa-frown-o:before{content:""}.fa-meh-o:before{content:""}.fa-gamepad:before{content:""}.fa-keyboard-o:before{content:""}.fa-flag-o:before{content:""}.fa-flag-checkered:before{content:""}.fa-terminal:before{content:""}.fa-code:before{content:""}.fa-mail-reply-all:before,.fa-reply-all:before{content:""}.fa-star-half-empty:before,.fa-star-half-full:before,.fa-star-half-o:before{content:""}.fa-location-arrow:before{content:""}.fa-crop:before{content:""}.fa-code-fork:before{content:""}.fa-unlink:before,.fa-chain-broken:before{content:""}.fa-question:before{content:""}.fa-info:before{content:""}.fa-exclamation:before{content:""}.fa-superscript:before{content:""}.fa-subscript:before{content:""}.fa-eraser:before{content:""}.fa-puzzle-piece:before{content:""}.fa-microphone:before{content:""}.fa-microphone-slash:
before{content:""}.fa-shield:before{content:""}.fa-calendar-o:before{content:""}.fa-fire-extinguisher:before{content:""}.fa-rocket:before{content:""}.fa-maxcdn:before{content:""}.fa-chevron-circle-left:before{content:""}.fa-chevron-circle-right:before{content:""}.fa-chevron-circle-up:before{content:""}.fa-chevron-circle-down:before{content:""}.fa-html5:before{content:""}.fa-css3:before{content:""}.fa-anchor:before{content:""}.fa-unlock-alt:before{content:""}.fa-bullseye:before{content:""}.fa-ellipsis-h:before{content:""}.fa-ellipsis-v:before{content:""}.fa-rss-square:before{content:""}.fa-play-circle:before{content:""}.fa-ticket:before{content:""}.fa-minus-square:before{content:""}.fa-minus-square-o:before,.wy-menu-vertical li.on a span.toctree-expand:before,.wy-menu-vertical li.current>a span.toctree-expand:before{content:""}.fa-level-up:before{content:""}.fa-level-down:before{content:""}.fa-check-square:before{content:""}.fa-pencil-square:before{content:""}.fa-external-link-square:before{content:""}.fa-share-square:before{content:""}.fa-compass:before{content:""}.fa-toggle-down:before,.fa-caret-square-o-down:before{content:""}.fa-toggle-up:before,.fa-caret-square-o-up:before{content:""}.fa-toggle-right:before,.fa-caret-square-o-right:before{content:""}.fa-euro:before,.fa-eur:before{content:""}.fa-gbp:before{content:""}.fa-dollar:before,.fa-usd:before{content:""}.fa-rupee:before,.fa-inr:before{content:""}.fa-cny:before,.fa-rmb:before,.fa-yen:before,.fa-jpy:before{content:""}.fa-ruble:before,.fa-rouble:before,.fa-rub:before{content:""}.fa-won:before,.fa-krw:before{content:""}.fa-bitcoin:before,.fa-btc:before{content:""}.fa-file:before{content:""}.fa-file-text:before{content:""}.fa-sort-alpha-asc:before{content:""}.fa-sort-alpha-desc:before{content:""}.fa-sort-amount-asc:before{content:""}.fa-sort-amount-desc:before{content:""}.fa-sort-numeric-asc:before{content:""}.fa-sort-numeric-desc:before{content:""}.fa-thumbs-up:before{content:""}.fa-thumbs-down:before{content:""}.fa-youtube-square:before{content:""}.fa-youtube:before{content:""}.fa-xing:before{content:""}.fa-xing-square:before{content:""}.fa-youtube-play:before{content:""}.fa-dropbox:before{content:""}.fa-stack-overflow:before{content:""}.fa-instagram:before{content:""}.fa-flickr:before{content:""}.fa-adn:before{content:""}.fa-bitbucket:before,.icon-bitbucket:before{content:""}.fa-bitbucket-square:before{content:""}.fa-tumblr:before{content:""}.fa-tumblr-square:before{content:""}.fa-long-arrow-down:before{content:""}.fa-long-arrow-up:before{content:""}.fa-long-arrow-left:before{content:""}.fa-long-arrow-right:before{content:""}.fa-apple:before{content:""}.fa-windows:before{content:""}.fa-android:before{content:""}.fa-linux:before{content:""}.fa-dribbble:before{content:""}.fa-skype:before{content:""}.fa-foursquare:before{content:""}.fa-trello:before{content:""}.fa-female:before{content:""}.fa-male:before{content:""}.fa-gittip:before,.fa-gratipay:before{content:""}.fa-sun-o:before{content:""}.fa-moon-o:before{content:""}.fa-archive:before{content:""}.fa-bug:before{content:""}.fa-vk:before{content:""}.fa-weibo:before{content:""}.fa-renren:before{content:""}.fa-pagelines:before{content:""}.fa-stack-exchange:before{content:""}.fa-arrow-circle-o-right:before{content:""}.fa-arrow-circle-o-left:before{content:""}.fa-toggle-left:before,.fa-caret-square-o-left:before{content:""}.fa-dot-circle-o:before{content:""}.fa-wheelchair:before{content:""}.fa-vime
o-square:before{content:""}.fa-turkish-lira:before,.fa-try:before{content:""}.fa-plus-square-o:before,.wy-menu-vertical li span.toctree-expand:before{content:""}.fa-space-shuttle:before{content:""}.fa-slack:before{content:""}.fa-envelope-square:before{content:""}.fa-wordpress:before{content:""}.fa-openid:before{content:""}.fa-institution:before,.fa-bank:before,.fa-university:before{content:""}.fa-mortar-board:before,.fa-graduation-cap:before{content:""}.fa-yahoo:before{content:""}.fa-google:before{content:""}.fa-reddit:before{content:""}.fa-reddit-square:before{content:""}.fa-stumbleupon-circle:before{content:""}.fa-stumbleupon:before{content:""}.fa-delicious:before{content:""}.fa-digg:before{content:""}.fa-pied-piper-pp:before{content:""}.fa-pied-piper-alt:before{content:""}.fa-drupal:before{content:""}.fa-joomla:before{content:""}.fa-language:before{content:""}.fa-fax:before{content:""}.fa-building:before{content:""}.fa-child:before{content:""}.fa-paw:before{content:""}.fa-spoon:before{content:""}.fa-cube:before{content:""}.fa-cubes:before{content:""}.fa-behance:before{content:""}.fa-behance-square:before{content:""}.fa-steam:before{content:""}.fa-steam-square:before{content:""}.fa-recycle:before{content:""}.fa-automobile:before,.fa-car:before{content:""}.fa-cab:before,.fa-taxi:before{content:""}.fa-tree:before{content:""}.fa-spotify:before{content:""}.fa-deviantart:before{content:""}.fa-soundcloud:before{content:""}.fa-database:before{content:""}.fa-file-pdf-o:before{content:""}.fa-file-word-o:before{content:""}.fa-file-excel-o:before{content:""}.fa-file-powerpoint-o:before{content:""}.fa-file-photo-o:before,.fa-file-picture-o:before,.fa-file-image-o:before{content:""}.fa-file-zip-o:before,.fa-file-archive-o:before{content:""}.fa-file-sound-o:before,.fa-file-audio-o:before{content:""}.fa-file-movie-o:before,.fa-file-video-o:before{content:""}.fa-file-code-o:before{content:""}.fa-vine:before{content:""}.fa-codepen:before{content:""}.fa-jsfiddle:before{content:""}.fa-life-bouy:before,.fa-life-buoy:before,.fa-life-saver:before,.fa-support:before,.fa-life-ring:before{content:""}.fa-circle-o-notch:before{content:""}.fa-ra:before,.fa-resistance:before,.fa-rebel:before{content:""}.fa-ge:before,.fa-empire:before{content:""}.fa-git-square:before{content:""}.fa-git:before{content:""}.fa-y-combinator-square:before,.fa-yc-square:before,.fa-hacker-news:before{content:""}.fa-tencent-weibo:before{content:""}.fa-qq:before{content:""}.fa-wechat:before,.fa-weixin:before{content:""}.fa-send:before,.fa-paper-plane:before{content:""}.fa-send-o:before,.fa-paper-plane-o:before{content:""}.fa-history:before{content:""}.fa-circle-thin:before{content:""}.fa-header:before{content:""}.fa-paragraph:before{content:""}.fa-sliders:before{content:""}.fa-share-alt:before{content:""}.fa-share-alt-square:before{content:""}.fa-bomb:before{content:""}.fa-soccer-ball-o:before,.fa-futbol-o:before{content:""}.fa-tty:before{content:""}.fa-binoculars:before{content:""}.fa-plug:before{content:""}.fa-slideshare:before{content:""}.fa-twitch:before{content:""}.fa-yelp:before{content:""}.fa-newspaper-o:before{content:""}.fa-wifi:before{content:""}.fa-calculator:before{content:""}.fa-paypal:before{content:""}.fa-google-wallet:before{content:""}.fa-cc-visa:before{content:""}.fa-cc-mastercard:before{content:""}.fa-cc-discover:before{content:""}.fa-cc-amex:before{content:""}.fa-cc-paypal:before{content:""}.fa-cc-stripe:before{content:""}.fa-bell-sl
ash:before{content:""}.fa-bell-slash-o:before{content:""}.fa-trash:before{content:""}.fa-copyright:before{content:""}.fa-at:before{content:""}.fa-eyedropper:before{content:""}.fa-paint-brush:before{content:""}.fa-birthday-cake:before{content:""}.fa-area-chart:before{content:""}.fa-pie-chart:before{content:""}.fa-line-chart:before{content:""}.fa-lastfm:before{content:""}.fa-lastfm-square:before{content:""}.fa-toggle-off:before{content:""}.fa-toggle-on:before{content:""}.fa-bicycle:before{content:""}.fa-bus:before{content:""}.fa-ioxhost:before{content:""}.fa-angellist:before{content:""}.fa-cc:before{content:""}.fa-shekel:before,.fa-sheqel:before,.fa-ils:before{content:""}.fa-meanpath:before{content:""}.fa-buysellads:before{content:""}.fa-connectdevelop:before{content:""}.fa-dashcube:before{content:""}.fa-forumbee:before{content:""}.fa-leanpub:before{content:""}.fa-sellsy:before{content:""}.fa-shirtsinbulk:before{content:""}.fa-simplybuilt:before{content:""}.fa-skyatlas:before{content:""}.fa-cart-plus:before{content:""}.fa-cart-arrow-down:before{content:""}.fa-diamond:before{content:""}.fa-ship:before{content:""}.fa-user-secret:before{content:""}.fa-motorcycle:before{content:""}.fa-street-view:before{content:""}.fa-heartbeat:before{content:""}.fa-venus:before{content:""}.fa-mars:before{content:""}.fa-mercury:before{content:""}.fa-intersex:before,.fa-transgender:before{content:""}.fa-transgender-alt:before{content:""}.fa-venus-double:before{content:""}.fa-mars-double:before{content:""}.fa-venus-mars:before{content:""}.fa-mars-stroke:before{content:""}.fa-mars-stroke-v:before{content:""}.fa-mars-stroke-h:before{content:""}.fa-neuter:before{content:""}.fa-genderless:before{content:""}.fa-facebook-official:before{content:""}.fa-pinterest-p:before{content:""}.fa-whatsapp:before{content:""}.fa-server:before{content:""}.fa-user-plus:before{content:""}.fa-user-times:before{content:""}.fa-hotel:before,.fa-bed:before{content:""}.fa-viacoin:before{content:""}.fa-train:before{content:""}.fa-subway:before{content:""}.fa-medium:before{content:""}.fa-yc:before,.fa-y-combinator:before{content:""}.fa-optin-monster:before{content:""}.fa-opencart:before{content:""}.fa-expeditedssl:before{content:""}.fa-battery-4:before,.fa-battery:before,.fa-battery-full:before{content:""}.fa-battery-3:before,.fa-battery-three-quarters:before{content:""}.fa-battery-2:before,.fa-battery-half:before{content:""}.fa-battery-1:before,.fa-battery-quarter:before{content:""}.fa-battery-0:before,.fa-battery-empty:before{content:""}.fa-mouse-pointer:before{content:""}.fa-i-cursor:before{content:""}.fa-object-group:before{content:""}.fa-object-ungroup:before{content:""}.fa-sticky-note:before{content:""}.fa-sticky-note-o:before{content:""}.fa-cc-jcb:before{content:""}.fa-cc-diners-club:before{content:""}.fa-clone:before{content:""}.fa-balance-scale:before{content:""}.fa-hourglass-o:before{content:""}.fa-hourglass-1:before,.fa-hourglass-start:before{content:""}.fa-hourglass-2:before,.fa-hourglass-half:before{content:""}.fa-hourglass-3:before,.fa-hourglass-end:before{content:""}.fa-hourglass:before{content:""}.fa-hand-grab-o:before,.fa-hand-rock-o:before{content:""}.fa-hand-stop-o:before,.fa-hand-paper-o:before{content:""}.fa-hand-scissors-o:before{content:""}.fa-hand-lizard-o:before{content:""}.fa-hand-spock-o:before{content:""}.fa-hand-pointer-o:before{content:""}.fa-hand-peace-o:before{content:""}.fa-trademark:before{content:""}.fa-registered:bef
ore{content:""}.fa-creative-commons:before{content:""}.fa-gg:before{content:""}.fa-gg-circle:before{content:""}.fa-tripadvisor:before{content:""}.fa-odnoklassniki:before{content:""}.fa-odnoklassniki-square:before{content:""}.fa-get-pocket:before{content:""}.fa-wikipedia-w:before{content:""}.fa-safari:before{content:""}.fa-chrome:before{content:""}.fa-firefox:before{content:""}.fa-opera:before{content:""}.fa-internet-explorer:before{content:""}.fa-tv:before,.fa-television:before{content:""}.fa-contao:before{content:""}.fa-500px:before{content:""}.fa-amazon:before{content:""}.fa-calendar-plus-o:before{content:""}.fa-calendar-minus-o:before{content:""}.fa-calendar-times-o:before{content:""}.fa-calendar-check-o:before{content:""}.fa-industry:before{content:""}.fa-map-pin:before{content:""}.fa-map-signs:before{content:""}.fa-map-o:before{content:""}.fa-map:before{content:""}.fa-commenting:before{content:""}.fa-commenting-o:before{content:""}.fa-houzz:before{content:""}.fa-vimeo:before{content:""}.fa-black-tie:before{content:""}.fa-fonticons:before{content:""}.fa-reddit-alien:before{content:""}.fa-edge:before{content:""}.fa-credit-card-alt:before{content:""}.fa-codiepie:before{content:""}.fa-modx:before{content:""}.fa-fort-awesome:before{content:""}.fa-usb:before{content:""}.fa-product-hunt:before{content:""}.fa-mixcloud:before{content:""}.fa-scribd:before{content:""}.fa-pause-circle:before{content:""}.fa-pause-circle-o:before{content:""}.fa-stop-circle:before{content:""}.fa-stop-circle-o:before{content:""}.fa-shopping-bag:before{content:""}.fa-shopping-basket:before{content:""}.fa-hashtag:before{content:""}.fa-bluetooth:before{content:""}.fa-bluetooth-b:before{content:""}.fa-percent:before{content:""}.fa-gitlab:before,.icon-gitlab:before{content:""}.fa-wpbeginner:before{content:""}.fa-wpforms:before{content:""}.fa-envira:before{content:""}.fa-universal-access:before{content:""}.fa-wheelchair-alt:before{content:""}.fa-question-circle-o:before{content:""}.fa-blind:before{content:""}.fa-audio-description:before{content:""}.fa-volume-control-phone:before{content:""}.fa-braille:before{content:""}.fa-assistive-listening-systems:before{content:""}.fa-asl-interpreting:before,.fa-american-sign-language-interpreting:before{content:""}.fa-deafness:before,.fa-hard-of-hearing:before,.fa-deaf:before{content:""}.fa-glide:before{content:""}.fa-glide-g:before{content:""}.fa-signing:before,.fa-sign-language:before{content:""}.fa-low-vision:before{content:""}.fa-viadeo:before{content:""}.fa-viadeo-square:before{content:""}.fa-snapchat:before{content:""}.fa-snapchat-ghost:before{content:""}.fa-snapchat-square:before{content:""}.fa-pied-piper:before{content:""}.fa-first-order:before{content:""}.fa-yoast:before{content:""}.fa-themeisle:before{content:""}.fa-google-plus-circle:before,.fa-google-plus-official:before{content:""}.fa-fa:before,.fa-font-awesome:before{content:""}.fa-handshake-o:before{content:""}.fa-envelope-open:before{content:""}.fa-envelope-open-o:before{content:""}.fa-linode:before{content:""}.fa-address-book:before{content:""}.fa-address-book-o:before{content:""}.fa-vcard:before,.fa-address-card:before{content:""}.fa-vcard-o:before,.fa-address-card-o:before{content:""}.fa-user-circle:before{content:""}.fa-user-circle-o:before{content:""}.fa-user-o:before{content:""}.fa-id-badge:before{content:""}.fa-drivers-license:before,.fa-id-card:before{content:""}.fa-drivers-license-o:before,.fa-id-card-o:before{content
:""}.fa-quora:before{content:""}.fa-free-code-camp:before{content:""}.fa-telegram:before{content:""}.fa-thermometer-4:before,.fa-thermometer:before,.fa-thermometer-full:before{content:""}.fa-thermometer-3:before,.fa-thermometer-three-quarters:before{content:""}.fa-thermometer-2:before,.fa-thermometer-half:before{content:""}.fa-thermometer-1:before,.fa-thermometer-quarter:before{content:""}.fa-thermometer-0:before,.fa-thermometer-empty:before{content:""}.fa-shower:before{content:""}.fa-bathtub:before,.fa-s15:before,.fa-bath:before{content:""}.fa-podcast:before{content:""}.fa-window-maximize:before{content:""}.fa-window-minimize:before{content:""}.fa-window-restore:before{content:""}.fa-times-rectangle:before,.fa-window-close:before{content:""}.fa-times-rectangle-o:before,.fa-window-close-o:before{content:""}.fa-bandcamp:before{content:""}.fa-grav:before{content:""}.fa-etsy:before{content:""}.fa-imdb:before{content:""}.fa-ravelry:before{content:""}.fa-eercast:before{content:""}.fa-microchip:before{content:""}.fa-snowflake-o:before{content:""}.fa-superpowers:before{content:""}.fa-wpexplorer:before{content:""}.fa-meetup:before{content:""}.sr-only{position:absolute;width:1px;height:1px;padding:0;margin:-1px;overflow:hidden;clip:rect(0, 0, 0, 0);border:0}.sr-only-focusable:active,.sr-only-focusable:focus{position:static;width:auto;height:auto;margin:0;overflow:visible;clip:auto}.fa,.wy-menu-vertical li span.toctree-expand,.wy-menu-vertical li.on a span.toctree-expand,.wy-menu-vertical li.current>a span.toctree-expand,.rst-content .admonition-title,.rst-content h1 .headerlink,.rst-content h2 .headerlink,.rst-content h3 .headerlink,.rst-content h4 .headerlink,.rst-content h5 .headerlink,.rst-content h6 .headerlink,.rst-content dl dt .headerlink,.rst-content p.caption .headerlink,.rst-content table>caption .headerlink,.rst-content tt.download span:first-child,.rst-content code.download span:first-child,.icon,.wy-dropdown .caret,.wy-inline-validate.wy-inline-validate-success .wy-input-context,.wy-inline-validate.wy-inline-validate-danger .wy-input-context,.wy-inline-validate.wy-inline-validate-warning .wy-input-context,.wy-inline-validate.wy-inline-validate-info .wy-input-context{font-family:inherit}.fa:before,.wy-menu-vertical li span.toctree-expand:before,.wy-menu-vertical li.on a span.toctree-expand:before,.wy-menu-vertical li.current>a span.toctree-expand:before,.rst-content .admonition-title:before,.rst-content h1 .headerlink:before,.rst-content h2 .headerlink:before,.rst-content h3 .headerlink:before,.rst-content h4 .headerlink:before,.rst-content h5 .headerlink:before,.rst-content h6 .headerlink:before,.rst-content dl dt .headerlink:before,.rst-content p.caption .headerlink:before,.rst-content table>caption .headerlink:before,.rst-content tt.download span:first-child:before,.rst-content code.download span:first-child:before,.icon:before,.wy-dropdown .caret:before,.wy-inline-validate.wy-inline-validate-success .wy-input-context:before,.wy-inline-validate.wy-inline-validate-danger .wy-input-context:before,.wy-inline-validate.wy-inline-validate-warning .wy-input-context:before,.wy-inline-validate.wy-inline-validate-info .wy-input-context:before{font-family:"FontAwesome";display:inline-block;font-style:normal;font-weight:normal;line-height:1;text-decoration:inherit}a .fa,a .wy-menu-vertical li span.toctree-expand,.wy-menu-vertical li a span.toctree-expand,.wy-menu-vertical li.on a span.toctree-expand,.wy-menu-vertical li.current>a span.toctree-expand,a .rst-content 
.admonition-title,.rst-content a .admonition-title,a .rst-content h1 .headerlink,.rst-content h1 a .headerlink,a .rst-content h2 .headerlink,.rst-content h2 a .headerlink,a .rst-content h3 .headerlink,.rst-content h3 a .headerlink,a .rst-content h4 .headerlink,.rst-content h4 a .headerlink,a .rst-content h5 .headerlink,.rst-content h5 a .headerlink,a .rst-content h6 .headerlink,.rst-content h6 a .headerlink,a .rst-content dl dt .headerlink,.rst-content dl dt a .headerlink,a .rst-content p.caption .headerlink,.rst-content p.caption a .headerlink,a .rst-content table>caption .headerlink,.rst-content table>caption a .headerlink,a .rst-content tt.download span:first-child,.rst-content tt.download a span:first-child,a .rst-content code.download span:first-child,.rst-content code.download a span:first-child,a .icon{display:inline-block;text-decoration:inherit}.btn .fa,.btn .wy-menu-vertical li span.toctree-expand,.wy-menu-vertical li .btn span.toctree-expand,.btn .wy-menu-vertical li.on a span.toctree-expand,.wy-menu-vertical li.on a .btn span.toctree-expand,.btn .wy-menu-vertical li.current>a span.toctree-expand,.wy-menu-vertical li.current>a .btn span.toctree-expand,.btn .rst-content .admonition-title,.rst-content .btn .admonition-title,.btn .rst-content h1 .headerlink,.rst-content h1 .btn .headerlink,.btn .rst-content h2 .headerlink,.rst-content h2 .btn .headerlink,.btn .rst-content h3 .headerlink,.rst-content h3 .btn .headerlink,.btn .rst-content h4 .headerlink,.rst-content h4 .btn .headerlink,.btn .rst-content h5 .headerlink,.rst-content h5 .btn .headerlink,.btn .rst-content h6 .headerlink,.rst-content h6 .btn .headerlink,.btn .rst-content dl dt .headerlink,.rst-content dl dt .btn .headerlink,.btn .rst-content p.caption .headerlink,.rst-content p.caption .btn .headerlink,.btn .rst-content table>caption .headerlink,.rst-content table>caption .btn .headerlink,.btn .rst-content tt.download span:first-child,.rst-content tt.download .btn span:first-child,.btn .rst-content code.download span:first-child,.rst-content code.download .btn span:first-child,.btn .icon,.nav .fa,.nav .wy-menu-vertical li span.toctree-expand,.wy-menu-vertical li .nav span.toctree-expand,.nav .wy-menu-vertical li.on a span.toctree-expand,.wy-menu-vertical li.on a .nav span.toctree-expand,.nav .wy-menu-vertical li.current>a span.toctree-expand,.wy-menu-vertical li.current>a .nav span.toctree-expand,.nav .rst-content .admonition-title,.rst-content .nav .admonition-title,.nav .rst-content h1 .headerlink,.rst-content h1 .nav .headerlink,.nav .rst-content h2 .headerlink,.rst-content h2 .nav .headerlink,.nav .rst-content h3 .headerlink,.rst-content h3 .nav .headerlink,.nav .rst-content h4 .headerlink,.rst-content h4 .nav .headerlink,.nav .rst-content h5 .headerlink,.rst-content h5 .nav .headerlink,.nav .rst-content h6 .headerlink,.rst-content h6 .nav .headerlink,.nav .rst-content dl dt .headerlink,.rst-content dl dt .nav .headerlink,.nav .rst-content p.caption .headerlink,.rst-content p.caption .nav .headerlink,.nav .rst-content table>caption .headerlink,.rst-content table>caption .nav .headerlink,.nav .rst-content tt.download span:first-child,.rst-content tt.download .nav span:first-child,.nav .rst-content code.download span:first-child,.rst-content code.download .nav span:first-child,.nav .icon{display:inline}.btn .fa.fa-large,.btn .wy-menu-vertical li span.fa-large.toctree-expand,.wy-menu-vertical li .btn span.fa-large.toctree-expand,.btn .rst-content .fa-large.admonition-title,.rst-content .btn 
.fa-large.admonition-title,.btn .rst-content h1 .fa-large.headerlink,.rst-content h1 .btn .fa-large.headerlink,.btn .rst-content h2 .fa-large.headerlink,.rst-content h2 .btn .fa-large.headerlink,.btn .rst-content h3 .fa-large.headerlink,.rst-content h3 .btn .fa-large.headerlink,.btn .rst-content h4 .fa-large.headerlink,.rst-content h4 .btn .fa-large.headerlink,.btn .rst-content h5 .fa-large.headerlink,.rst-content h5 .btn .fa-large.headerlink,.btn .rst-content h6 .fa-large.headerlink,.rst-content h6 .btn .fa-large.headerlink,.btn .rst-content dl dt .fa-large.headerlink,.rst-content dl dt .btn .fa-large.headerlink,.btn .rst-content p.caption .fa-large.headerlink,.rst-content p.caption .btn .fa-large.headerlink,.btn .rst-content table>caption .fa-large.headerlink,.rst-content table>caption .btn .fa-large.headerlink,.btn .rst-content tt.download span.fa-large:first-child,.rst-content tt.download .btn span.fa-large:first-child,.btn .rst-content code.download span.fa-large:first-child,.rst-content code.download .btn span.fa-large:first-child,.btn .fa-large.icon,.nav .fa.fa-large,.nav .wy-menu-vertical li span.fa-large.toctree-expand,.wy-menu-vertical li .nav span.fa-large.toctree-expand,.nav .rst-content .fa-large.admonition-title,.rst-content .nav .fa-large.admonition-title,.nav .rst-content h1 .fa-large.headerlink,.rst-content h1 .nav .fa-large.headerlink,.nav .rst-content h2 .fa-large.headerlink,.rst-content h2 .nav .fa-large.headerlink,.nav .rst-content h3 .fa-large.headerlink,.rst-content h3 .nav .fa-large.headerlink,.nav .rst-content h4 .fa-large.headerlink,.rst-content h4 .nav .fa-large.headerlink,.nav .rst-content h5 .fa-large.headerlink,.rst-content h5 .nav .fa-large.headerlink,.nav .rst-content h6 .fa-large.headerlink,.rst-content h6 .nav .fa-large.headerlink,.nav .rst-content dl dt .fa-large.headerlink,.rst-content dl dt .nav .fa-large.headerlink,.nav .rst-content p.caption .fa-large.headerlink,.rst-content p.caption .nav .fa-large.headerlink,.nav .rst-content table>caption .fa-large.headerlink,.rst-content table>caption .nav .fa-large.headerlink,.nav .rst-content tt.download span.fa-large:first-child,.rst-content tt.download .nav span.fa-large:first-child,.nav .rst-content code.download span.fa-large:first-child,.rst-content code.download .nav span.fa-large:first-child,.nav .fa-large.icon{line-height:.9em}.btn .fa.fa-spin,.btn .wy-menu-vertical li span.fa-spin.toctree-expand,.wy-menu-vertical li .btn span.fa-spin.toctree-expand,.btn .rst-content .fa-spin.admonition-title,.rst-content .btn .fa-spin.admonition-title,.btn .rst-content h1 .fa-spin.headerlink,.rst-content h1 .btn .fa-spin.headerlink,.btn .rst-content h2 .fa-spin.headerlink,.rst-content h2 .btn .fa-spin.headerlink,.btn .rst-content h3 .fa-spin.headerlink,.rst-content h3 .btn .fa-spin.headerlink,.btn .rst-content h4 .fa-spin.headerlink,.rst-content h4 .btn .fa-spin.headerlink,.btn .rst-content h5 .fa-spin.headerlink,.rst-content h5 .btn .fa-spin.headerlink,.btn .rst-content h6 .fa-spin.headerlink,.rst-content h6 .btn .fa-spin.headerlink,.btn .rst-content dl dt .fa-spin.headerlink,.rst-content dl dt .btn .fa-spin.headerlink,.btn .rst-content p.caption .fa-spin.headerlink,.rst-content p.caption .btn .fa-spin.headerlink,.btn .rst-content table>caption .fa-spin.headerlink,.rst-content table>caption .btn .fa-spin.headerlink,.btn .rst-content tt.download span.fa-spin:first-child,.rst-content tt.download .btn span.fa-spin:first-child,.btn .rst-content code.download span.fa-spin:first-child,.rst-content code.download .btn 
span.fa-spin:first-child,.btn .fa-spin.icon,.nav .fa.fa-spin,.nav .wy-menu-vertical li span.fa-spin.toctree-expand,.wy-menu-vertical li .nav span.fa-spin.toctree-expand,.nav .rst-content .fa-spin.admonition-title,.rst-content .nav .fa-spin.admonition-title,.nav .rst-content h1 .fa-spin.headerlink,.rst-content h1 .nav .fa-spin.headerlink,.nav .rst-content h2 .fa-spin.headerlink,.rst-content h2 .nav .fa-spin.headerlink,.nav .rst-content h3 .fa-spin.headerlink,.rst-content h3 .nav .fa-spin.headerlink,.nav .rst-content h4 .fa-spin.headerlink,.rst-content h4 .nav .fa-spin.headerlink,.nav .rst-content h5 .fa-spin.headerlink,.rst-content h5 .nav .fa-spin.headerlink,.nav .rst-content h6 .fa-spin.headerlink,.rst-content h6 .nav .fa-spin.headerlink,.nav .rst-content dl dt .fa-spin.headerlink,.rst-content dl dt .nav .fa-spin.headerlink,.nav .rst-content p.caption .fa-spin.headerlink,.rst-content p.caption .nav .fa-spin.headerlink,.nav .rst-content table>caption .fa-spin.headerlink,.rst-content table>caption .nav .fa-spin.headerlink,.nav .rst-content tt.download span.fa-spin:first-child,.rst-content tt.download .nav span.fa-spin:first-child,.nav .rst-content code.download span.fa-spin:first-child,.rst-content code.download .nav span.fa-spin:first-child,.nav .fa-spin.icon{display:inline-block}.btn.fa:before,.wy-menu-vertical li span.btn.toctree-expand:before,.rst-content .btn.admonition-title:before,.rst-content h1 .btn.headerlink:before,.rst-content h2 .btn.headerlink:before,.rst-content h3 .btn.headerlink:before,.rst-content h4 .btn.headerlink:before,.rst-content h5 .btn.headerlink:before,.rst-content h6 .btn.headerlink:before,.rst-content dl dt .btn.headerlink:before,.rst-content p.caption .btn.headerlink:before,.rst-content table>caption .btn.headerlink:before,.rst-content tt.download span.btn:first-child:before,.rst-content code.download span.btn:first-child:before,.btn.icon:before{opacity:.5;-webkit-transition:opacity .05s ease-in;-moz-transition:opacity .05s ease-in;transition:opacity .05s ease-in}.btn.fa:hover:before,.wy-menu-vertical li span.btn.toctree-expand:hover:before,.rst-content .btn.admonition-title:hover:before,.rst-content h1 .btn.headerlink:hover:before,.rst-content h2 .btn.headerlink:hover:before,.rst-content h3 .btn.headerlink:hover:before,.rst-content h4 .btn.headerlink:hover:before,.rst-content h5 .btn.headerlink:hover:before,.rst-content h6 .btn.headerlink:hover:before,.rst-content dl dt .btn.headerlink:hover:before,.rst-content p.caption .btn.headerlink:hover:before,.rst-content table>caption .btn.headerlink:hover:before,.rst-content tt.download span.btn:first-child:hover:before,.rst-content code.download span.btn:first-child:hover:before,.btn.icon:hover:before{opacity:1}.btn-mini .fa:before,.btn-mini .wy-menu-vertical li span.toctree-expand:before,.wy-menu-vertical li .btn-mini span.toctree-expand:before,.btn-mini .rst-content .admonition-title:before,.rst-content .btn-mini .admonition-title:before,.btn-mini .rst-content h1 .headerlink:before,.rst-content h1 .btn-mini .headerlink:before,.btn-mini .rst-content h2 .headerlink:before,.rst-content h2 .btn-mini .headerlink:before,.btn-mini .rst-content h3 .headerlink:before,.rst-content h3 .btn-mini .headerlink:before,.btn-mini .rst-content h4 .headerlink:before,.rst-content h4 .btn-mini .headerlink:before,.btn-mini .rst-content h5 .headerlink:before,.rst-content h5 .btn-mini .headerlink:before,.btn-mini .rst-content h6 .headerlink:before,.rst-content h6 .btn-mini .headerlink:before,.btn-mini .rst-content dl dt 
.headerlink:before,.rst-content dl dt .btn-mini .headerlink:before,.btn-mini .rst-content p.caption .headerlink:before,.rst-content p.caption .btn-mini .headerlink:before,.btn-mini .rst-content table>caption .headerlink:before,.rst-content table>caption .btn-mini .headerlink:before,.btn-mini .rst-content tt.download span:first-child:before,.rst-content tt.download .btn-mini span:first-child:before,.btn-mini .rst-content code.download span:first-child:before,.rst-content code.download .btn-mini span:first-child:before,.btn-mini .icon:before{font-size:14px;vertical-align:-15%}.wy-alert,.rst-content .note,.rst-content .attention,.rst-content .caution,.rst-content .danger,.rst-content .error,.rst-content .hint,.rst-content .important,.rst-content .tip,.rst-content .warning,.rst-content .seealso,.rst-content .admonition-todo,.rst-content .admonition{padding:12px;line-height:24px;margin-bottom:24px;background:#e7f2fa}.wy-alert-title,.rst-content .admonition-title{color:#fff;font-weight:bold;display:block;color:#fff;background:#6ab0de;margin:-12px;padding:6px 12px;margin-bottom:12px}.wy-alert.wy-alert-danger,.rst-content .wy-alert-danger.note,.rst-content .wy-alert-danger.attention,.rst-content .wy-alert-danger.caution,.rst-content .danger,.rst-content .error,.rst-content .wy-alert-danger.hint,.rst-content .wy-alert-danger.important,.rst-content .wy-alert-danger.tip,.rst-content .wy-alert-danger.warning,.rst-content .wy-alert-danger.seealso,.rst-content .wy-alert-danger.admonition-todo,.rst-content .wy-alert-danger.admonition{background:#fdf3f2}.wy-alert.wy-alert-danger .wy-alert-title,.rst-content .wy-alert-danger.note .wy-alert-title,.rst-content .wy-alert-danger.attention .wy-alert-title,.rst-content .wy-alert-danger.caution .wy-alert-title,.rst-content .danger .wy-alert-title,.rst-content .error .wy-alert-title,.rst-content .wy-alert-danger.hint .wy-alert-title,.rst-content .wy-alert-danger.important .wy-alert-title,.rst-content .wy-alert-danger.tip .wy-alert-title,.rst-content .wy-alert-danger.warning .wy-alert-title,.rst-content .wy-alert-danger.seealso .wy-alert-title,.rst-content .wy-alert-danger.admonition-todo .wy-alert-title,.rst-content .wy-alert-danger.admonition .wy-alert-title,.wy-alert.wy-alert-danger .rst-content .admonition-title,.rst-content .wy-alert.wy-alert-danger .admonition-title,.rst-content .wy-alert-danger.note .admonition-title,.rst-content .wy-alert-danger.attention .admonition-title,.rst-content .wy-alert-danger.caution .admonition-title,.rst-content .danger .admonition-title,.rst-content .error .admonition-title,.rst-content .wy-alert-danger.hint .admonition-title,.rst-content .wy-alert-danger.important .admonition-title,.rst-content .wy-alert-danger.tip .admonition-title,.rst-content .wy-alert-danger.warning .admonition-title,.rst-content .wy-alert-danger.seealso .admonition-title,.rst-content .wy-alert-danger.admonition-todo .admonition-title,.rst-content .wy-alert-danger.admonition .admonition-title{background:#f29f97}.wy-alert.wy-alert-warning,.rst-content .wy-alert-warning.note,.rst-content .attention,.rst-content .caution,.rst-content .wy-alert-warning.danger,.rst-content .wy-alert-warning.error,.rst-content .wy-alert-warning.hint,.rst-content .wy-alert-warning.important,.rst-content .wy-alert-warning.tip,.rst-content .warning,.rst-content .wy-alert-warning.seealso,.rst-content .admonition-todo,.rst-content .wy-alert-warning.admonition{background:#ffedcc}.wy-alert.wy-alert-warning .wy-alert-title,.rst-content .wy-alert-warning.note .wy-alert-title,.rst-content 
.attention .wy-alert-title,.rst-content .caution .wy-alert-title,.rst-content .wy-alert-warning.danger .wy-alert-title,.rst-content .wy-alert-warning.error .wy-alert-title,.rst-content .wy-alert-warning.hint .wy-alert-title,.rst-content .wy-alert-warning.important .wy-alert-title,.rst-content .wy-alert-warning.tip .wy-alert-title,.rst-content .warning .wy-alert-title,.rst-content .wy-alert-warning.seealso .wy-alert-title,.rst-content .admonition-todo .wy-alert-title,.rst-content .wy-alert-warning.admonition .wy-alert-title,.wy-alert.wy-alert-warning .rst-content .admonition-title,.rst-content .wy-alert.wy-alert-warning .admonition-title,.rst-content .wy-alert-warning.note .admonition-title,.rst-content .attention .admonition-title,.rst-content .caution .admonition-title,.rst-content .wy-alert-warning.danger .admonition-title,.rst-content .wy-alert-warning.error .admonition-title,.rst-content .wy-alert-warning.hint .admonition-title,.rst-content .wy-alert-warning.important .admonition-title,.rst-content .wy-alert-warning.tip .admonition-title,.rst-content .warning .admonition-title,.rst-content .wy-alert-warning.seealso .admonition-title,.rst-content .admonition-todo .admonition-title,.rst-content .wy-alert-warning.admonition .admonition-title{background:#f0b37e}.wy-alert.wy-alert-info,.rst-content .note,.rst-content .wy-alert-info.attention,.rst-content .wy-alert-info.caution,.rst-content .wy-alert-info.danger,.rst-content .wy-alert-info.error,.rst-content .wy-alert-info.hint,.rst-content .wy-alert-info.important,.rst-content .wy-alert-info.tip,.rst-content .wy-alert-info.warning,.rst-content .seealso,.rst-content .wy-alert-info.admonition-todo,.rst-content .wy-alert-info.admonition{background:#e7f2fa}.wy-alert.wy-alert-info .wy-alert-title,.rst-content .note .wy-alert-title,.rst-content .wy-alert-info.attention .wy-alert-title,.rst-content .wy-alert-info.caution .wy-alert-title,.rst-content .wy-alert-info.danger .wy-alert-title,.rst-content .wy-alert-info.error .wy-alert-title,.rst-content .wy-alert-info.hint .wy-alert-title,.rst-content .wy-alert-info.important .wy-alert-title,.rst-content .wy-alert-info.tip .wy-alert-title,.rst-content .wy-alert-info.warning .wy-alert-title,.rst-content .seealso .wy-alert-title,.rst-content .wy-alert-info.admonition-todo .wy-alert-title,.rst-content .wy-alert-info.admonition .wy-alert-title,.wy-alert.wy-alert-info .rst-content .admonition-title,.rst-content .wy-alert.wy-alert-info .admonition-title,.rst-content .note .admonition-title,.rst-content .wy-alert-info.attention .admonition-title,.rst-content .wy-alert-info.caution .admonition-title,.rst-content .wy-alert-info.danger .admonition-title,.rst-content .wy-alert-info.error .admonition-title,.rst-content .wy-alert-info.hint .admonition-title,.rst-content .wy-alert-info.important .admonition-title,.rst-content .wy-alert-info.tip .admonition-title,.rst-content .wy-alert-info.warning .admonition-title,.rst-content .seealso .admonition-title,.rst-content .wy-alert-info.admonition-todo .admonition-title,.rst-content .wy-alert-info.admonition .admonition-title{background:#6ab0de}.wy-alert.wy-alert-success,.rst-content .wy-alert-success.note,.rst-content .wy-alert-success.attention,.rst-content .wy-alert-success.caution,.rst-content .wy-alert-success.danger,.rst-content .wy-alert-success.error,.rst-content .hint,.rst-content .important,.rst-content .tip,.rst-content .wy-alert-success.warning,.rst-content .wy-alert-success.seealso,.rst-content .wy-alert-success.admonition-todo,.rst-content 
.wy-alert-success.admonition{background:#dbfaf4}.wy-alert.wy-alert-success .wy-alert-title,.rst-content .wy-alert-success.note .wy-alert-title,.rst-content .wy-alert-success.attention .wy-alert-title,.rst-content .wy-alert-success.caution .wy-alert-title,.rst-content .wy-alert-success.danger .wy-alert-title,.rst-content .wy-alert-success.error .wy-alert-title,.rst-content .hint .wy-alert-title,.rst-content .important .wy-alert-title,.rst-content .tip .wy-alert-title,.rst-content .wy-alert-success.warning .wy-alert-title,.rst-content .wy-alert-success.seealso .wy-alert-title,.rst-content .wy-alert-success.admonition-todo .wy-alert-title,.rst-content .wy-alert-success.admonition .wy-alert-title,.wy-alert.wy-alert-success .rst-content .admonition-title,.rst-content .wy-alert.wy-alert-success .admonition-title,.rst-content .wy-alert-success.note .admonition-title,.rst-content .wy-alert-success.attention .admonition-title,.rst-content .wy-alert-success.caution .admonition-title,.rst-content .wy-alert-success.danger .admonition-title,.rst-content .wy-alert-success.error .admonition-title,.rst-content .hint .admonition-title,.rst-content .important .admonition-title,.rst-content .tip .admonition-title,.rst-content .wy-alert-success.warning .admonition-title,.rst-content .wy-alert-success.seealso .admonition-title,.rst-content .wy-alert-success.admonition-todo .admonition-title,.rst-content .wy-alert-success.admonition .admonition-title{background:#1abc9c}.wy-alert.wy-alert-neutral,.rst-content .wy-alert-neutral.note,.rst-content .wy-alert-neutral.attention,.rst-content .wy-alert-neutral.caution,.rst-content .wy-alert-neutral.danger,.rst-content .wy-alert-neutral.error,.rst-content .wy-alert-neutral.hint,.rst-content .wy-alert-neutral.important,.rst-content .wy-alert-neutral.tip,.rst-content .wy-alert-neutral.warning,.rst-content .wy-alert-neutral.seealso,.rst-content .wy-alert-neutral.admonition-todo,.rst-content .wy-alert-neutral.admonition{background:#f3f6f6}.wy-alert.wy-alert-neutral .wy-alert-title,.rst-content .wy-alert-neutral.note .wy-alert-title,.rst-content .wy-alert-neutral.attention .wy-alert-title,.rst-content .wy-alert-neutral.caution .wy-alert-title,.rst-content .wy-alert-neutral.danger .wy-alert-title,.rst-content .wy-alert-neutral.error .wy-alert-title,.rst-content .wy-alert-neutral.hint .wy-alert-title,.rst-content .wy-alert-neutral.important .wy-alert-title,.rst-content .wy-alert-neutral.tip .wy-alert-title,.rst-content .wy-alert-neutral.warning .wy-alert-title,.rst-content .wy-alert-neutral.seealso .wy-alert-title,.rst-content .wy-alert-neutral.admonition-todo .wy-alert-title,.rst-content .wy-alert-neutral.admonition .wy-alert-title,.wy-alert.wy-alert-neutral .rst-content .admonition-title,.rst-content .wy-alert.wy-alert-neutral .admonition-title,.rst-content .wy-alert-neutral.note .admonition-title,.rst-content .wy-alert-neutral.attention .admonition-title,.rst-content .wy-alert-neutral.caution .admonition-title,.rst-content .wy-alert-neutral.danger .admonition-title,.rst-content .wy-alert-neutral.error .admonition-title,.rst-content .wy-alert-neutral.hint .admonition-title,.rst-content .wy-alert-neutral.important .admonition-title,.rst-content .wy-alert-neutral.tip .admonition-title,.rst-content .wy-alert-neutral.warning .admonition-title,.rst-content .wy-alert-neutral.seealso .admonition-title,.rst-content .wy-alert-neutral.admonition-todo .admonition-title,.rst-content .wy-alert-neutral.admonition .admonition-title{color:#404040;background:#e1e4e5}.wy-alert.wy-alert-neutral 
a,.rst-content .wy-alert-neutral.note a,.rst-content .wy-alert-neutral.attention a,.rst-content .wy-alert-neutral.caution a,.rst-content .wy-alert-neutral.danger a,.rst-content .wy-alert-neutral.error a,.rst-content .wy-alert-neutral.hint a,.rst-content .wy-alert-neutral.important a,.rst-content .wy-alert-neutral.tip a,.rst-content .wy-alert-neutral.warning a,.rst-content .wy-alert-neutral.seealso a,.rst-content .wy-alert-neutral.admonition-todo a,.rst-content .wy-alert-neutral.admonition a{color:#2980B9}.wy-alert p:last-child,.rst-content .note p:last-child,.rst-content .attention p:last-child,.rst-content .caution p:last-child,.rst-content .danger p:last-child,.rst-content .error p:last-child,.rst-content .hint p:last-child,.rst-content .important p:last-child,.rst-content .tip p:last-child,.rst-content .warning p:last-child,.rst-content .seealso p:last-child,.rst-content .admonition-todo p:last-child,.rst-content .admonition p:last-child{margin-bottom:0}.wy-tray-container{position:fixed;bottom:0px;left:0;z-index:600}.wy-tray-container li{display:block;width:300px;background:transparent;color:#fff;text-align:center;box-shadow:0 5px 5px 0 rgba(0,0,0,0.1);padding:0 24px;min-width:20%;opacity:0;height:0;line-height:56px;overflow:hidden;-webkit-transition:all .3s ease-in;-moz-transition:all .3s ease-in;transition:all .3s ease-in}.wy-tray-container li.wy-tray-item-success{background:#27AE60}.wy-tray-container li.wy-tray-item-info{background:#2980B9}.wy-tray-container li.wy-tray-item-warning{background:#E67E22}.wy-tray-container li.wy-tray-item-danger{background:#E74C3C}.wy-tray-container li.on{opacity:1;height:56px}@media screen and (max-width: 768px){.wy-tray-container{bottom:auto;top:0;width:100%}.wy-tray-container li{width:100%}}button{font-size:100%;margin:0;vertical-align:baseline;*vertical-align:middle;cursor:pointer;line-height:normal;-webkit-appearance:button;*overflow:visible}button::-moz-focus-inner,input::-moz-focus-inner{border:0;padding:0}button[disabled]{cursor:default}.btn{display:inline-block;border-radius:2px;line-height:normal;white-space:nowrap;text-align:center;cursor:pointer;font-size:100%;padding:6px 12px 8px 12px;color:#fff;border:1px solid rgba(0,0,0,0.1);background-color:#27AE60;text-decoration:none;font-weight:normal;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;box-shadow:0px 1px 2px -1px rgba(255,255,255,0.5) inset,0px -2px 0px 0px rgba(0,0,0,0.1) inset;outline-none:false;vertical-align:middle;*display:inline;zoom:1;-webkit-user-drag:none;-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none;-webkit-transition:all .1s linear;-moz-transition:all .1s linear;transition:all .1s linear}.btn-hover{background:#2e8ece;color:#fff}.btn:hover{background:#2cc36b;color:#fff}.btn:focus{background:#2cc36b;outline:0}.btn:active{box-shadow:0px -1px 0px 0px rgba(0,0,0,0.05) inset,0px 2px 0px 0px rgba(0,0,0,0.1) inset;padding:8px 12px 6px 12px}.btn:visited{color:#fff}.btn:disabled{background-image:none;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);filter:alpha(opacity=40);opacity:.4;cursor:not-allowed;box-shadow:none}.btn-disabled{background-image:none;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);filter:alpha(opacity=40);opacity:.4;cursor:not-allowed;box-shadow:none}.btn-disabled:hover,.btn-disabled:focus,.btn-disabled:active{background-image:none;filter:progid:DXImageTransform.Microsoft.gradient(enabled = 
false);filter:alpha(opacity=40);opacity:.4;cursor:not-allowed;box-shadow:none}.btn::-moz-focus-inner{padding:0;border:0}.btn-small{font-size:80%}.btn-info{background-color:#2980B9 !important}.btn-info:hover{background-color:#2e8ece !important}.btn-neutral{background-color:#f3f6f6 !important;color:#404040 !important}.btn-neutral:hover{background-color:#e5ebeb !important;color:#404040}.btn-neutral:visited{color:#404040 !important}.btn-success{background-color:#27AE60 !important}.btn-success:hover{background-color:#295 !important}.btn-danger{background-color:#E74C3C !important}.btn-danger:hover{background-color:#ea6153 !important}.btn-warning{background-color:#E67E22 !important}.btn-warning:hover{background-color:#e98b39 !important}.btn-invert{background-color:#222}.btn-invert:hover{background-color:#2f2f2f !important}.btn-link{background-color:transparent !important;color:#2980B9;box-shadow:none;border-color:transparent !important}.btn-link:hover{background-color:transparent !important;color:#409ad5 !important;box-shadow:none}.btn-link:active{background-color:transparent !important;color:#409ad5 !important;box-shadow:none}.btn-link:visited{color:#9B59B6}.wy-btn-group .btn,.wy-control .btn{vertical-align:middle}.wy-btn-group{margin-bottom:24px;*zoom:1}.wy-btn-group:before,.wy-btn-group:after{display:table;content:""}.wy-btn-group:after{clear:both}.wy-dropdown{position:relative;display:inline-block}.wy-dropdown-active .wy-dropdown-menu{display:block}.wy-dropdown-menu{position:absolute;left:0;display:none;float:left;top:100%;min-width:100%;background:#fcfcfc;z-index:100;border:solid 1px #cfd7dd;box-shadow:0 2px 2px 0 rgba(0,0,0,0.1);padding:12px}.wy-dropdown-menu>dd>a{display:block;clear:both;color:#404040;white-space:nowrap;font-size:90%;padding:0 12px;cursor:pointer}.wy-dropdown-menu>dd>a:hover{background:#2980B9;color:#fff}.wy-dropdown-menu>dd.divider{border-top:solid 1px #cfd7dd;margin:6px 0}.wy-dropdown-menu>dd.search{padding-bottom:12px}.wy-dropdown-menu>dd.search input[type="search"]{width:100%}.wy-dropdown-menu>dd.call-to-action{background:#e3e3e3;text-transform:uppercase;font-weight:500;font-size:80%}.wy-dropdown-menu>dd.call-to-action:hover{background:#e3e3e3}.wy-dropdown-menu>dd.call-to-action .btn{color:#fff}.wy-dropdown.wy-dropdown-up .wy-dropdown-menu{bottom:100%;top:auto;left:auto;right:0}.wy-dropdown.wy-dropdown-bubble .wy-dropdown-menu{background:#fcfcfc;margin-top:2px}.wy-dropdown.wy-dropdown-bubble .wy-dropdown-menu a{padding:6px 12px}.wy-dropdown.wy-dropdown-bubble .wy-dropdown-menu a:hover{background:#2980B9;color:#fff}.wy-dropdown.wy-dropdown-left .wy-dropdown-menu{right:0;left:auto;text-align:right}.wy-dropdown-arrow:before{content:" ";border-bottom:5px solid #f5f5f5;border-left:5px solid transparent;border-right:5px solid transparent;position:absolute;display:block;top:-4px;left:50%;margin-left:-3px}.wy-dropdown-arrow.wy-dropdown-arrow-left:before{left:11px}.wy-form-stacked select{display:block}.wy-form-aligned input,.wy-form-aligned textarea,.wy-form-aligned select,.wy-form-aligned .wy-help-inline,.wy-form-aligned label{display:inline-block;*display:inline;*zoom:1;vertical-align:middle}.wy-form-aligned .wy-control-group>label{display:inline-block;vertical-align:middle;width:10em;margin:6px 12px 0 0;float:left}.wy-form-aligned .wy-control{float:left}.wy-form-aligned .wy-control label{display:block}.wy-form-aligned .wy-control 
select{margin-top:6px}fieldset{border:0;margin:0;padding:0}legend{display:block;width:100%;border:0;padding:0;white-space:normal;margin-bottom:24px;font-size:150%;*margin-left:-7px}label{display:block;margin:0 0 .3125em 0;color:#333;font-size:90%}input,select,textarea{font-size:100%;margin:0;vertical-align:baseline;*vertical-align:middle}.wy-control-group{margin-bottom:24px;*zoom:1;max-width:68em;margin-left:auto;margin-right:auto;*zoom:1}.wy-control-group:before,.wy-control-group:after{display:table;content:""}.wy-control-group:after{clear:both}.wy-control-group:before,.wy-control-group:after{display:table;content:""}.wy-control-group:after{clear:both}.wy-control-group.wy-control-group-required>label:after{content:" *";color:#E74C3C}.wy-control-group .wy-form-full,.wy-control-group .wy-form-halves,.wy-control-group .wy-form-thirds{padding-bottom:12px}.wy-control-group .wy-form-full select,.wy-control-group .wy-form-halves select,.wy-control-group .wy-form-thirds select{width:100%}.wy-control-group .wy-form-full input[type="text"],.wy-control-group .wy-form-full input[type="password"],.wy-control-group .wy-form-full input[type="email"],.wy-control-group .wy-form-full input[type="url"],.wy-control-group .wy-form-full input[type="date"],.wy-control-group .wy-form-full input[type="month"],.wy-control-group .wy-form-full input[type="time"],.wy-control-group .wy-form-full input[type="datetime"],.wy-control-group .wy-form-full input[type="datetime-local"],.wy-control-group .wy-form-full input[type="week"],.wy-control-group .wy-form-full input[type="number"],.wy-control-group .wy-form-full input[type="search"],.wy-control-group .wy-form-full input[type="tel"],.wy-control-group .wy-form-full input[type="color"],.wy-control-group .wy-form-halves input[type="text"],.wy-control-group .wy-form-halves input[type="password"],.wy-control-group .wy-form-halves input[type="email"],.wy-control-group .wy-form-halves input[type="url"],.wy-control-group .wy-form-halves input[type="date"],.wy-control-group .wy-form-halves input[type="month"],.wy-control-group .wy-form-halves input[type="time"],.wy-control-group .wy-form-halves input[type="datetime"],.wy-control-group .wy-form-halves input[type="datetime-local"],.wy-control-group .wy-form-halves input[type="week"],.wy-control-group .wy-form-halves input[type="number"],.wy-control-group .wy-form-halves input[type="search"],.wy-control-group .wy-form-halves input[type="tel"],.wy-control-group .wy-form-halves input[type="color"],.wy-control-group .wy-form-thirds input[type="text"],.wy-control-group .wy-form-thirds input[type="password"],.wy-control-group .wy-form-thirds input[type="email"],.wy-control-group .wy-form-thirds input[type="url"],.wy-control-group .wy-form-thirds input[type="date"],.wy-control-group .wy-form-thirds input[type="month"],.wy-control-group .wy-form-thirds input[type="time"],.wy-control-group .wy-form-thirds input[type="datetime"],.wy-control-group .wy-form-thirds input[type="datetime-local"],.wy-control-group .wy-form-thirds input[type="week"],.wy-control-group .wy-form-thirds input[type="number"],.wy-control-group .wy-form-thirds input[type="search"],.wy-control-group .wy-form-thirds input[type="tel"],.wy-control-group .wy-form-thirds input[type="color"]{width:100%}.wy-control-group .wy-form-full{float:left;display:block;margin-right:2.3576515979%;width:100%;margin-right:0}.wy-control-group .wy-form-full:last-child{margin-right:0}.wy-control-group 
.wy-form-halves{float:left;display:block;margin-right:2.3576515979%;width:48.821174201%}.wy-control-group .wy-form-halves:last-child{margin-right:0}.wy-control-group .wy-form-halves:nth-of-type(2n){margin-right:0}.wy-control-group .wy-form-halves:nth-of-type(2n+1){clear:left}.wy-control-group .wy-form-thirds{float:left;display:block;margin-right:2.3576515979%;width:31.7615656014%}.wy-control-group .wy-form-thirds:last-child{margin-right:0}.wy-control-group .wy-form-thirds:nth-of-type(3n){margin-right:0}.wy-control-group .wy-form-thirds:nth-of-type(3n+1){clear:left}.wy-control-group.wy-control-group-no-input .wy-control{margin:6px 0 0 0;font-size:90%}.wy-control-no-input{display:inline-block;margin:6px 0 0 0;font-size:90%}.wy-control-group.fluid-input input[type="text"],.wy-control-group.fluid-input input[type="password"],.wy-control-group.fluid-input input[type="email"],.wy-control-group.fluid-input input[type="url"],.wy-control-group.fluid-input input[type="date"],.wy-control-group.fluid-input input[type="month"],.wy-control-group.fluid-input input[type="time"],.wy-control-group.fluid-input input[type="datetime"],.wy-control-group.fluid-input input[type="datetime-local"],.wy-control-group.fluid-input input[type="week"],.wy-control-group.fluid-input input[type="number"],.wy-control-group.fluid-input input[type="search"],.wy-control-group.fluid-input input[type="tel"],.wy-control-group.fluid-input input[type="color"]{width:100%}.wy-form-message-inline{display:inline-block;padding-left:.3em;color:#666;vertical-align:middle;font-size:90%}.wy-form-message{display:block;color:#999;font-size:70%;margin-top:.3125em;font-style:italic}.wy-form-message p{font-size:inherit;font-style:italic;margin-bottom:6px}.wy-form-message p:last-child{margin-bottom:0}input{line-height:normal}input[type="button"],input[type="reset"],input[type="submit"]{-webkit-appearance:button;cursor:pointer;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;*overflow:visible}input[type="text"],input[type="password"],input[type="email"],input[type="url"],input[type="date"],input[type="month"],input[type="time"],input[type="datetime"],input[type="datetime-local"],input[type="week"],input[type="number"],input[type="search"],input[type="tel"],input[type="color"]{-webkit-appearance:none;padding:6px;display:inline-block;border:1px solid #ccc;font-size:80%;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;box-shadow:inset 0 1px 3px #ddd;border-radius:0;-webkit-transition:border .3s linear;-moz-transition:border .3s linear;transition:border .3s linear}input[type="datetime-local"]{padding:.34375em .625em}input[disabled]{cursor:default}input[type="checkbox"],input[type="radio"]{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;padding:0;margin-right:.3125em;*height:13px;*width:13px}input[type="search"]{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}input[type="search"]::-webkit-search-cancel-button,input[type="search"]::-webkit-search-decoration{-webkit-appearance:none}input[type="text"]:focus,input[type="password"]:focus,input[type="email"]:focus,input[type="url"]:focus,input[type="date"]:focus,input[type="month"]:focus,input[type="time"]:focus,input[type="datetime"]:focus,input[type="datetime-local"]:focus,input[type="week"]:focus,input[type="number"]:focus,input[type="search"]:focus,input[type="tel"]:focus,input[type="color"]:focus{outline:0;outline:thin dotted \9;border-color:#333}input.no-focus:focus{border-color:#ccc 
!important}input[type="file"]:focus,input[type="radio"]:focus,input[type="checkbox"]:focus{outline:thin dotted #333;outline:1px auto #129FEA}input[type="text"][disabled],input[type="password"][disabled],input[type="email"][disabled],input[type="url"][disabled],input[type="date"][disabled],input[type="month"][disabled],input[type="time"][disabled],input[type="datetime"][disabled],input[type="datetime-local"][disabled],input[type="week"][disabled],input[type="number"][disabled],input[type="search"][disabled],input[type="tel"][disabled],input[type="color"][disabled]{cursor:not-allowed;background-color:#fafafa}input:focus:invalid,textarea:focus:invalid,select:focus:invalid{color:#E74C3C;border:1px solid #E74C3C}input:focus:invalid:focus,textarea:focus:invalid:focus,select:focus:invalid:focus{border-color:#E74C3C}input[type="file"]:focus:invalid:focus,input[type="radio"]:focus:invalid:focus,input[type="checkbox"]:focus:invalid:focus{outline-color:#E74C3C}input.wy-input-large{padding:12px;font-size:100%}textarea{overflow:auto;vertical-align:top;width:100%;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif}select,textarea{padding:.5em .625em;display:inline-block;border:1px solid #ccc;font-size:80%;box-shadow:inset 0 1px 3px #ddd;-webkit-transition:border .3s linear;-moz-transition:border .3s linear;transition:border .3s linear}select{border:1px solid #ccc;background-color:#fff}select[multiple]{height:auto}select:focus,textarea:focus{outline:0}select[disabled],textarea[disabled],input[readonly],select[readonly],textarea[readonly]{cursor:not-allowed;background-color:#fafafa}input[type="radio"][disabled],input[type="checkbox"][disabled]{cursor:not-allowed}.wy-checkbox,.wy-radio{margin:6px 0;color:#404040;display:block}.wy-checkbox input,.wy-radio input{vertical-align:baseline}.wy-form-message-inline{display:inline-block;*display:inline;*zoom:1;vertical-align:middle}.wy-input-prefix,.wy-input-suffix{white-space:nowrap;padding:6px}.wy-input-prefix .wy-input-context,.wy-input-suffix .wy-input-context{line-height:27px;padding:0 8px;display:inline-block;font-size:80%;background-color:#f3f6f6;border:solid 1px #ccc;color:#999}.wy-input-suffix .wy-input-context{border-left:0}.wy-input-prefix .wy-input-context{border-right:0}.wy-switch{position:relative;display:block;height:24px;margin-top:12px;cursor:pointer}.wy-switch:before{position:absolute;content:"";display:block;left:0;top:0;width:36px;height:12px;border-radius:4px;background:#ccc;-webkit-transition:all .2s ease-in-out;-moz-transition:all .2s ease-in-out;transition:all .2s ease-in-out}.wy-switch:after{position:absolute;content:"";display:block;width:18px;height:18px;border-radius:4px;background:#999;left:-3px;top:-3px;-webkit-transition:all .2s ease-in-out;-moz-transition:all .2s ease-in-out;transition:all .2s ease-in-out}.wy-switch span{position:absolute;left:48px;display:block;font-size:12px;color:#ccc;line-height:1}.wy-switch.active:before{background:#1e8449}.wy-switch.active:after{left:24px;background:#27AE60}.wy-switch.disabled{cursor:not-allowed;opacity:.8}.wy-control-group.wy-control-group-error .wy-form-message,.wy-control-group.wy-control-group-error>label{color:#E74C3C}.wy-control-group.wy-control-group-error input[type="text"],.wy-control-group.wy-control-group-error input[type="password"],.wy-control-group.wy-control-group-error input[type="email"],.wy-control-group.wy-control-group-error input[type="url"],.wy-control-group.wy-control-group-error input[type="date"],.wy-control-group.wy-control-group-error 
input[type="month"],.wy-control-group.wy-control-group-error input[type="time"],.wy-control-group.wy-control-group-error input[type="datetime"],.wy-control-group.wy-control-group-error input[type="datetime-local"],.wy-control-group.wy-control-group-error input[type="week"],.wy-control-group.wy-control-group-error input[type="number"],.wy-control-group.wy-control-group-error input[type="search"],.wy-control-group.wy-control-group-error input[type="tel"],.wy-control-group.wy-control-group-error input[type="color"]{border:solid 1px #E74C3C}.wy-control-group.wy-control-group-error textarea{border:solid 1px #E74C3C}.wy-inline-validate{white-space:nowrap}.wy-inline-validate .wy-input-context{padding:.5em .625em;display:inline-block;font-size:80%}.wy-inline-validate.wy-inline-validate-success .wy-input-context{color:#27AE60}.wy-inline-validate.wy-inline-validate-danger .wy-input-context{color:#E74C3C}.wy-inline-validate.wy-inline-validate-warning .wy-input-context{color:#E67E22}.wy-inline-validate.wy-inline-validate-info .wy-input-context{color:#2980B9}.rotate-90{-webkit-transform:rotate(90deg);-moz-transform:rotate(90deg);-ms-transform:rotate(90deg);-o-transform:rotate(90deg);transform:rotate(90deg)}.rotate-180{-webkit-transform:rotate(180deg);-moz-transform:rotate(180deg);-ms-transform:rotate(180deg);-o-transform:rotate(180deg);transform:rotate(180deg)}.rotate-270{-webkit-transform:rotate(270deg);-moz-transform:rotate(270deg);-ms-transform:rotate(270deg);-o-transform:rotate(270deg);transform:rotate(270deg)}.mirror{-webkit-transform:scaleX(-1);-moz-transform:scaleX(-1);-ms-transform:scaleX(-1);-o-transform:scaleX(-1);transform:scaleX(-1)}.mirror.rotate-90{-webkit-transform:scaleX(-1) rotate(90deg);-moz-transform:scaleX(-1) rotate(90deg);-ms-transform:scaleX(-1) rotate(90deg);-o-transform:scaleX(-1) rotate(90deg);transform:scaleX(-1) rotate(90deg)}.mirror.rotate-180{-webkit-transform:scaleX(-1) rotate(180deg);-moz-transform:scaleX(-1) rotate(180deg);-ms-transform:scaleX(-1) rotate(180deg);-o-transform:scaleX(-1) rotate(180deg);transform:scaleX(-1) rotate(180deg)}.mirror.rotate-270{-webkit-transform:scaleX(-1) rotate(270deg);-moz-transform:scaleX(-1) rotate(270deg);-ms-transform:scaleX(-1) rotate(270deg);-o-transform:scaleX(-1) rotate(270deg);transform:scaleX(-1) rotate(270deg)}@media only screen and (max-width: 480px){.wy-form button[type="submit"]{margin:.7em 0 0}.wy-form input[type="text"],.wy-form input[type="password"],.wy-form input[type="email"],.wy-form input[type="url"],.wy-form input[type="date"],.wy-form input[type="month"],.wy-form input[type="time"],.wy-form input[type="datetime"],.wy-form input[type="datetime-local"],.wy-form input[type="week"],.wy-form input[type="number"],.wy-form input[type="search"],.wy-form input[type="tel"],.wy-form input[type="color"]{margin-bottom:.3em;display:block}.wy-form label{margin-bottom:.3em;display:block}.wy-form input[type="password"],.wy-form input[type="email"],.wy-form input[type="url"],.wy-form input[type="date"],.wy-form input[type="month"],.wy-form input[type="time"],.wy-form input[type="datetime"],.wy-form input[type="datetime-local"],.wy-form input[type="week"],.wy-form input[type="number"],.wy-form input[type="search"],.wy-form input[type="tel"],.wy-form input[type="color"]{margin-bottom:0}.wy-form-aligned .wy-control-group label{margin-bottom:.3em;text-align:left;display:block;width:100%}.wy-form-aligned .wy-control{margin:1.5em 0 0 0}.wy-form .wy-help-inline,.wy-form-message-inline,.wy-form-message{display:block;font-size:80%;padding:6px 
0}}@media screen and (max-width: 768px){.tablet-hide{display:none}}@media screen and (max-width: 480px){.mobile-hide{display:none}}.float-left{float:left}.float-right{float:right}.full-width{width:100%}.wy-table,.rst-content table.docutils,.rst-content table.field-list{border-collapse:collapse;border-spacing:0;empty-cells:show;margin-bottom:24px}.wy-table caption,.rst-content table.docutils caption,.rst-content table.field-list caption{color:#000;font:italic 85%/1 arial,sans-serif;padding:1em 0;text-align:center}.wy-table td,.rst-content table.docutils td,.rst-content table.field-list td,.wy-table th,.rst-content table.docutils th,.rst-content table.field-list th{font-size:90%;margin:0;overflow:visible;padding:8px 16px}.wy-table td:first-child,.rst-content table.docutils td:first-child,.rst-content table.field-list td:first-child,.wy-table th:first-child,.rst-content table.docutils th:first-child,.rst-content table.field-list th:first-child{border-left-width:0}.wy-table thead,.rst-content table.docutils thead,.rst-content table.field-list thead{color:#000;text-align:left;vertical-align:bottom;white-space:nowrap}.wy-table thead th,.rst-content table.docutils thead th,.rst-content table.field-list thead th{font-weight:bold;border-bottom:solid 2px #e1e4e5}.wy-table td,.rst-content table.docutils td,.rst-content table.field-list td{background-color:transparent;vertical-align:middle}.wy-table td p,.rst-content table.docutils td p,.rst-content table.field-list td p{line-height:18px}.wy-table td p:last-child,.rst-content table.docutils td p:last-child,.rst-content table.field-list td p:last-child{margin-bottom:0}.wy-table .wy-table-cell-min,.rst-content table.docutils .wy-table-cell-min,.rst-content table.field-list .wy-table-cell-min{width:1%;padding-right:0}.wy-table .wy-table-cell-min input[type=checkbox],.rst-content table.docutils .wy-table-cell-min input[type=checkbox],.rst-content table.field-list .wy-table-cell-min input[type=checkbox],.wy-table .wy-table-cell-min input[type=checkbox],.rst-content table.docutils .wy-table-cell-min input[type=checkbox],.rst-content table.field-list .wy-table-cell-min input[type=checkbox]{margin:0}.wy-table-secondary{color:gray;font-size:90%}.wy-table-tertiary{color:gray;font-size:80%}.wy-table-odd td,.wy-table-striped tr:nth-child(2n-1) td,.rst-content table.docutils:not(.field-list) tr:nth-child(2n-1) td{background-color:#f3f6f6}.wy-table-backed{background-color:#f3f6f6}.wy-table-bordered-all,.rst-content table.docutils{border:1px solid #e1e4e5}.wy-table-bordered-all td,.rst-content table.docutils td{border-bottom:1px solid #e1e4e5;border-left:1px solid #e1e4e5}.wy-table-bordered-all tbody>tr:last-child td,.rst-content table.docutils tbody>tr:last-child td{border-bottom-width:0}.wy-table-bordered{border:1px solid #e1e4e5}.wy-table-bordered-rows td{border-bottom:1px solid #e1e4e5}.wy-table-bordered-rows tbody>tr:last-child td{border-bottom-width:0}.wy-table-horizontal tbody>tr:last-child td{border-bottom-width:0}.wy-table-horizontal td,.wy-table-horizontal th{border-width:0 0 1px 0;border-bottom:1px solid #e1e4e5}.wy-table-horizontal tbody>tr:last-child td{border-bottom-width:0}.wy-table-responsive{margin-bottom:24px;max-width:100%;overflow:auto}.wy-table-responsive table{margin-bottom:0 !important}.wy-table-responsive table td,.wy-table-responsive table th{white-space:nowrap}a{color:#2980B9;text-decoration:none;cursor:pointer}a:hover{color:#3091d1}a:visited{color:#9B59B6}html{height:100%;overflow-x:hidden}body{font-family:"Lato","proxima-nova","Helvetica 
Neue",Arial,sans-serif;font-weight:normal;color:#404040;min-height:100%;overflow-x:hidden;background:#edf0f2}.wy-text-left{text-align:left}.wy-text-center{text-align:center}.wy-text-right{text-align:right}.wy-text-large{font-size:120%}.wy-text-normal{font-size:100%}.wy-text-small,small{font-size:80%}.wy-text-strike{text-decoration:line-through}.wy-text-warning{color:#E67E22 !important}a.wy-text-warning:hover{color:#eb9950 !important}.wy-text-info{color:#2980B9 !important}a.wy-text-info:hover{color:#409ad5 !important}.wy-text-success{color:#27AE60 !important}a.wy-text-success:hover{color:#36d278 !important}.wy-text-danger{color:#E74C3C !important}a.wy-text-danger:hover{color:#ed7669 !important}.wy-text-neutral{color:#404040 !important}a.wy-text-neutral:hover{color:#595959 !important}h1,h2,.rst-content .toctree-wrapper p.caption,h3,h4,h5,h6,legend{margin-top:0;font-weight:700;font-family:"Roboto Slab","ff-tisa-web-pro","Georgia",Arial,sans-serif}p{line-height:24px;margin:0;font-size:16px;margin-bottom:24px}h1{font-size:175%}h2,.rst-content .toctree-wrapper p.caption{font-size:150%}h3{font-size:125%}h4{font-size:115%}h5{font-size:110%}h6{font-size:100%}hr{display:block;height:1px;border:0;border-top:1px solid #e1e4e5;margin:24px 0;padding:0}code,.rst-content tt,.rst-content code{white-space:nowrap;max-width:100%;background:#fff;border:solid 1px #e1e4e5;font-size:75%;padding:0 5px;font-family:SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace;color:#E74C3C;overflow-x:auto}code.code-large,.rst-content tt.code-large{font-size:90%}.wy-plain-list-disc,.rst-content .section ul,.rst-content .toctree-wrapper ul,article ul{list-style:disc;line-height:24px;margin-bottom:24px}.wy-plain-list-disc li,.rst-content .section ul li,.rst-content .toctree-wrapper ul li,article ul li{list-style:disc;margin-left:24px}.wy-plain-list-disc li p:last-child,.rst-content .section ul li p:last-child,.rst-content .toctree-wrapper ul li p:last-child,article ul li p:last-child{margin-bottom:0}.wy-plain-list-disc li ul,.rst-content .section ul li ul,.rst-content .toctree-wrapper ul li ul,article ul li ul{margin-bottom:0}.wy-plain-list-disc li li,.rst-content .section ul li li,.rst-content .toctree-wrapper ul li li,article ul li li{list-style:circle}.wy-plain-list-disc li li li,.rst-content .section ul li li li,.rst-content .toctree-wrapper ul li li li,article ul li li li{list-style:square}.wy-plain-list-disc li ol li,.rst-content .section ul li ol li,.rst-content .toctree-wrapper ul li ol li,article ul li ol li{list-style:decimal}.wy-plain-list-decimal,.rst-content .section ol,.rst-content ol.arabic,article ol{list-style:decimal;line-height:24px;margin-bottom:24px}.wy-plain-list-decimal li,.rst-content .section ol li,.rst-content ol.arabic li,article ol li{list-style:decimal;margin-left:24px}.wy-plain-list-decimal li p:last-child,.rst-content .section ol li p:last-child,.rst-content ol.arabic li p:last-child,article ol li p:last-child{margin-bottom:0}.wy-plain-list-decimal li ul,.rst-content .section ol li ul,.rst-content ol.arabic li ul,article ol li ul{margin-bottom:0}.wy-plain-list-decimal li ul li,.rst-content .section ol li ul li,.rst-content ol.arabic li ul li,article ol li ul li{list-style:disc}.wy-breadcrumbs{*zoom:1}.wy-breadcrumbs:before,.wy-breadcrumbs:after{display:table;content:""}.wy-breadcrumbs:after{clear:both}.wy-breadcrumbs li{display:inline-block}.wy-breadcrumbs li.wy-breadcrumbs-aside{float:right}.wy-breadcrumbs li a{display:inline-block;padding:5px}.wy-breadcrumbs li 
a:first-child{padding-left:0}.wy-breadcrumbs li code,.wy-breadcrumbs li .rst-content tt,.rst-content .wy-breadcrumbs li tt{padding:5px;border:none;background:none}.wy-breadcrumbs li code.literal,.wy-breadcrumbs li .rst-content tt.literal,.rst-content .wy-breadcrumbs li tt.literal{color:#404040}.wy-breadcrumbs-extra{margin-bottom:0;color:#b3b3b3;font-size:80%;display:inline-block}@media screen and (max-width: 480px){.wy-breadcrumbs-extra{display:none}.wy-breadcrumbs li.wy-breadcrumbs-aside{display:none}}@media print{.wy-breadcrumbs li.wy-breadcrumbs-aside{display:none}}html{font-size:16px}.wy-affix{position:fixed;top:1.618em}.wy-menu a:hover{text-decoration:none}.wy-menu-horiz{*zoom:1}.wy-menu-horiz:before,.wy-menu-horiz:after{display:table;content:""}.wy-menu-horiz:after{clear:both}.wy-menu-horiz ul,.wy-menu-horiz li{display:inline-block}.wy-menu-horiz li:hover{background:rgba(255,255,255,0.1)}.wy-menu-horiz li.divide-left{border-left:solid 1px #404040}.wy-menu-horiz li.divide-right{border-right:solid 1px #404040}.wy-menu-horiz a{height:32px;display:inline-block;line-height:32px;padding:0 16px}.wy-menu-vertical{width:300px}.wy-menu-vertical header,.wy-menu-vertical p.caption{height:32px;display:inline-block;line-height:32px;padding:0 1.618em;margin-bottom:0;display:block;font-weight:bold;text-transform:uppercase;font-size:80%;white-space:nowrap}.wy-menu-vertical ul{margin-bottom:0}.wy-menu-vertical li.divide-top{border-top:solid 1px #404040}.wy-menu-vertical li.divide-bottom{border-bottom:solid 1px #404040}.wy-menu-vertical li.current{background:#e3e3e3}.wy-menu-vertical li.current a{color:gray;border-right:solid 1px #c9c9c9;padding:.4045em 2.427em}.wy-menu-vertical li.current a:hover{background:#d6d6d6}.wy-menu-vertical li code,.wy-menu-vertical li .rst-content tt,.rst-content .wy-menu-vertical li tt{border:none;background:inherit;color:inherit;padding-left:0;padding-right:0}.wy-menu-vertical li span.toctree-expand{display:block;float:left;margin-left:-1.2em;font-size:.8em;line-height:1.6em;color:#4d4d4d}.wy-menu-vertical li.on a,.wy-menu-vertical li.current>a{color:#404040;padding:.4045em 1.618em;font-weight:bold;position:relative;background:#fcfcfc;border:none}.wy-menu-vertical li.on a:hover,.wy-menu-vertical li.current>a:hover{background:#fcfcfc}.wy-menu-vertical li.on a:hover span.toctree-expand,.wy-menu-vertical li.current>a:hover span.toctree-expand{color:gray}.wy-menu-vertical li.on a span.toctree-expand,.wy-menu-vertical li.current>a span.toctree-expand{display:block;font-size:.8em;line-height:1.6em;color:#333}.wy-menu-vertical li.toctree-l1.current>a{border-bottom:solid 1px #c9c9c9;border-top:solid 1px #c9c9c9}.wy-menu-vertical li.toctree-l2 a,.wy-menu-vertical li.toctree-l3 a,.wy-menu-vertical li.toctree-l4 a{color:#404040}.wy-menu-vertical li.toctree-l1.current li.toctree-l2>ul,.wy-menu-vertical li.toctree-l2.current li.toctree-l3>ul{display:none}.wy-menu-vertical li.toctree-l1.current li.toctree-l2.current>ul,.wy-menu-vertical li.toctree-l2.current li.toctree-l3.current>ul{display:block}.wy-menu-vertical li.toctree-l2.current>a{background:#c9c9c9;padding:.4045em 2.427em}.wy-menu-vertical li.toctree-l2.current li.toctree-l3>a{display:block;background:#c9c9c9;padding:.4045em 4.045em}.wy-menu-vertical li.toctree-l2 a:hover span.toctree-expand{color:gray}.wy-menu-vertical li.toctree-l2 span.toctree-expand{color:#a3a3a3}.wy-menu-vertical li.toctree-l3{font-size:.9em}.wy-menu-vertical li.toctree-l3.current>a{background:#bdbdbd;padding:.4045em 
4.045em}.wy-menu-vertical li.toctree-l3.current li.toctree-l4>a{display:block;background:#bdbdbd;padding:.4045em 5.663em}.wy-menu-vertical li.toctree-l3 a:hover span.toctree-expand{color:gray}.wy-menu-vertical li.toctree-l3 span.toctree-expand{color:#969696}.wy-menu-vertical li.toctree-l4{font-size:.9em}.wy-menu-vertical li.current ul{display:block}.wy-menu-vertical li ul{margin-bottom:0;display:none}.wy-menu-vertical li ul li a{margin-bottom:0;color:#d9d9d9;font-weight:normal}.wy-menu-vertical a{display:inline-block;line-height:18px;padding:.4045em 1.618em;display:block;position:relative;font-size:90%;color:#d9d9d9}.wy-menu-vertical a:hover{background-color:#4e4a4a;cursor:pointer}.wy-menu-vertical a:hover span.toctree-expand{color:#d9d9d9}.wy-menu-vertical a:active{background-color:#2980B9;cursor:pointer;color:#fff}.wy-menu-vertical a:active span.toctree-expand{color:#fff}.wy-side-nav-search{display:block;width:300px;padding:.809em;margin-bottom:.809em;z-index:200;background-color:#2980B9;text-align:center;padding:.809em;display:block;color:#fcfcfc;margin-bottom:.809em}.wy-side-nav-search input[type=text]{width:100%;border-radius:50px;padding:6px 12px;border-color:#2472a4}.wy-side-nav-search img{display:block;margin:auto auto .809em auto;height:45px;width:45px;background-color:#2980B9;padding:5px;border-radius:100%}.wy-side-nav-search>a,.wy-side-nav-search .wy-dropdown>a{color:#fcfcfc;font-size:100%;font-weight:bold;display:inline-block;padding:4px 6px;margin-bottom:.809em}.wy-side-nav-search>a:hover,.wy-side-nav-search .wy-dropdown>a:hover{background:rgba(255,255,255,0.1)}.wy-side-nav-search>a img.logo,.wy-side-nav-search .wy-dropdown>a img.logo{display:block;margin:0 auto;height:auto;width:auto;border-radius:0;max-width:100%;background:transparent}.wy-side-nav-search>a.icon img.logo,.wy-side-nav-search .wy-dropdown>a.icon img.logo{margin-top:.85em}.wy-side-nav-search>div.version{margin-top:-.4045em;margin-bottom:.809em;font-weight:normal;color:rgba(255,255,255,0.3)}.wy-nav .wy-menu-vertical header{color:#2980B9}.wy-nav .wy-menu-vertical a{color:#b3b3b3}.wy-nav .wy-menu-vertical a:hover{background-color:#2980B9;color:#fff}[data-menu-wrap]{-webkit-transition:all .2s ease-in;-moz-transition:all .2s ease-in;transition:all .2s ease-in;position:absolute;opacity:1;width:100%;opacity:0}[data-menu-wrap].move-center{left:0;right:auto;opacity:1}[data-menu-wrap].move-left{right:auto;left:-100%;opacity:0}[data-menu-wrap].move-right{right:-100%;left:auto;opacity:0}.wy-body-for-nav{background:#fcfcfc}.wy-grid-for-nav{position:absolute;width:100%;height:100%}.wy-nav-side{position:fixed;top:0;bottom:0;left:0;padding-bottom:2em;width:300px;overflow-x:hidden;overflow-y:hidden;min-height:100%;color:#9b9b9b;background:#343131;z-index:200}.wy-side-scroll{width:320px;position:relative;overflow-x:hidden;overflow-y:scroll;height:100%}.wy-nav-top{display:none;background:#2980B9;color:#fff;padding:.4045em .809em;position:relative;line-height:50px;text-align:center;font-size:100%;*zoom:1}.wy-nav-top:before,.wy-nav-top:after{display:table;content:""}.wy-nav-top:after{clear:both}.wy-nav-top a{color:#fff;font-weight:bold}.wy-nav-top img{margin-right:12px;height:45px;width:45px;background-color:#2980B9;padding:5px;border-radius:100%}.wy-nav-top i{font-size:30px;float:left;cursor:pointer;padding-top:inherit}.wy-nav-content-wrap{margin-left:300px;background:#fcfcfc;min-height:100%}.wy-nav-content{padding:1.618em 
3.236em;height:100%;max-width:800px;margin:auto}.wy-body-mask{position:fixed;width:100%;height:100%;background:rgba(0,0,0,0.2);display:none;z-index:499}.wy-body-mask.on{display:block}footer{color:gray}footer p{margin-bottom:12px}footer span.commit code,footer span.commit .rst-content tt,.rst-content footer span.commit tt{padding:0px;font-family:SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace;font-size:1em;background:none;border:none;color:gray}.rst-footer-buttons{*zoom:1}.rst-footer-buttons:before,.rst-footer-buttons:after{width:100%}.rst-footer-buttons:before,.rst-footer-buttons:after{display:table;content:""}.rst-footer-buttons:after{clear:both}.rst-breadcrumbs-buttons{margin-top:12px;*zoom:1}.rst-breadcrumbs-buttons:before,.rst-breadcrumbs-buttons:after{display:table;content:""}.rst-breadcrumbs-buttons:after{clear:both}#search-results .search li{margin-bottom:24px;border-bottom:solid 1px #e1e4e5;padding-bottom:24px}#search-results .search li:first-child{border-top:solid 1px #e1e4e5;padding-top:24px}#search-results .search li a{font-size:120%;margin-bottom:12px;display:inline-block}#search-results .context{color:gray;font-size:90%}@media screen and (max-width: 768px){.wy-body-for-nav{background:#fcfcfc}.wy-nav-top{display:block}.wy-nav-side{left:-300px}.wy-nav-side.shift{width:85%;left:0}.wy-side-scroll{width:auto}.wy-side-nav-search{width:auto}.wy-menu.wy-menu-vertical{width:auto}.wy-nav-content-wrap{margin-left:0}.wy-nav-content-wrap .wy-nav-content{padding:1.618em}.wy-nav-content-wrap.shift{position:fixed;min-width:100%;left:85%;top:0;height:100%;overflow:hidden}}@media screen and (min-width: 1100px){.wy-nav-content-wrap{background:rgba(0,0,0,0.05)}.wy-nav-content{margin:0;background:#fcfcfc}}@media print{.rst-versions,footer,.wy-nav-side{display:none}.wy-nav-content-wrap{margin-left:0}}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;z-index:400}.rst-versions a{color:#2980B9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27AE60;*zoom:1}.rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after{display:table;content:""}.rst-versions .rst-current-version:after{clear:both}.rst-versions .rst-current-version .fa,.rst-versions .rst-current-version .wy-menu-vertical li span.toctree-expand,.wy-menu-vertical li .rst-versions .rst-current-version span.toctree-expand,.rst-versions .rst-current-version .rst-content .admonition-title,.rst-content .rst-versions .rst-current-version .admonition-title,.rst-versions .rst-current-version .rst-content h1 .headerlink,.rst-content h1 .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content h2 .headerlink,.rst-content h2 .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content h3 .headerlink,.rst-content h3 .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content h4 .headerlink,.rst-content h4 .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content h5 .headerlink,.rst-content h5 .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content h6 .headerlink,.rst-content h6 .rst-versions .rst-current-version .headerlink,.rst-versions 
.rst-current-version .rst-content dl dt .headerlink,.rst-content dl dt .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content p.caption .headerlink,.rst-content p.caption .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content table>caption .headerlink,.rst-content table>caption .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content tt.download span:first-child,.rst-content tt.download .rst-versions .rst-current-version span:first-child,.rst-versions .rst-current-version .rst-content code.download span:first-child,.rst-content code.download .rst-versions .rst-current-version span:first-child,.rst-versions .rst-current-version .icon{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#E74C3C;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#F1C40F;color:#000}.rst-versions.shift-up{height:auto;max-height:100%}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:gray;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:solid 1px #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px}.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge .rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width: 768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}}.rst-content img{max-width:100%;height:auto}.rst-content div.figure{margin-bottom:24px}.rst-content div.figure p.caption{font-style:italic}.rst-content div.figure p:last-child.caption{margin-bottom:0px}.rst-content div.figure.align-center{text-align:center}.rst-content .section>img,.rst-content .section>a>img{margin-bottom:24px}.rst-content abbr[title]{text-decoration:none}.rst-content.style-external-links a.reference.external:after{font-family:FontAwesome;content:"";color:#b3b3b3;vertical-align:super;font-size:60%;margin:0 .2em}.rst-content blockquote{margin-left:24px;line-height:24px;margin-bottom:24px}.rst-content pre.literal-block{white-space:pre;margin:0;padding:12px 12px;font-family:SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace;display:block;overflow:auto}.rst-content pre.literal-block,.rst-content div[class^='highlight']{border:1px solid #e1e4e5;overflow-x:auto;margin:1px 0 24px 0}.rst-content pre.literal-block div[class^='highlight'],.rst-content div[class^='highlight'] div[class^='highlight']{padding:0px;border:none;margin:0}.rst-content div[class^='highlight'] td.code{width:100%}.rst-content .linenodiv pre{border-right:solid 1px 
#e6e9ea;margin:0;padding:12px 12px;font-family:SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace;user-select:none;pointer-events:none}.rst-content div[class^='highlight'] pre{white-space:pre;margin:0;padding:12px 12px;display:block;overflow:auto}.rst-content div[class^='highlight'] pre .hll{display:block;margin:0 -12px;padding:0 12px}.rst-content pre.literal-block,.rst-content div[class^='highlight'] pre,.rst-content .linenodiv pre{font-family:SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace;font-size:12px;line-height:1.4}@media print{.rst-content .codeblock,.rst-content div[class^='highlight'],.rst-content div[class^='highlight'] pre{white-space:pre-wrap}}.rst-content .note .last,.rst-content .attention .last,.rst-content .caution .last,.rst-content .danger .last,.rst-content .error .last,.rst-content .hint .last,.rst-content .important .last,.rst-content .tip .last,.rst-content .warning .last,.rst-content .seealso .last,.rst-content .admonition-todo .last,.rst-content .admonition .last{margin-bottom:0}.rst-content .admonition-title:before{margin-right:4px}.rst-content .admonition table{border-color:rgba(0,0,0,0.1)}.rst-content .admonition table td,.rst-content .admonition table th{background:transparent !important;border-color:rgba(0,0,0,0.1) !important}.rst-content .section ol.loweralpha,.rst-content .section ol.loweralpha li{list-style:lower-alpha}.rst-content .section ol.upperalpha,.rst-content .section ol.upperalpha li{list-style:upper-alpha}.rst-content .section ol p,.rst-content .section ul p{margin-bottom:12px}.rst-content .section ol p:last-child,.rst-content .section ul p:last-child{margin-bottom:24px}.rst-content .line-block{margin-left:0px;margin-bottom:24px;line-height:24px}.rst-content .line-block .line-block{margin-left:24px;margin-bottom:0px}.rst-content .topic-title{font-weight:bold;margin-bottom:12px}.rst-content .toc-backref{color:#404040}.rst-content .align-right{float:right;margin:0px 0px 24px 24px}.rst-content .align-left{float:left;margin:0px 24px 24px 0px}.rst-content .align-center{margin:auto}.rst-content .align-center:not(table){display:block}.rst-content h1 .headerlink,.rst-content h2 .headerlink,.rst-content .toctree-wrapper p.caption .headerlink,.rst-content h3 .headerlink,.rst-content h4 .headerlink,.rst-content h5 .headerlink,.rst-content h6 .headerlink,.rst-content dl dt .headerlink,.rst-content p.caption .headerlink,.rst-content table>caption .headerlink{visibility:hidden;font-size:14px}.rst-content h1 .headerlink:after,.rst-content h2 .headerlink:after,.rst-content .toctree-wrapper p.caption .headerlink:after,.rst-content h3 .headerlink:after,.rst-content h4 .headerlink:after,.rst-content h5 .headerlink:after,.rst-content h6 .headerlink:after,.rst-content dl dt .headerlink:after,.rst-content p.caption .headerlink:after,.rst-content table>caption .headerlink:after{content:"";font-family:FontAwesome}.rst-content h1:hover .headerlink:after,.rst-content h2:hover .headerlink:after,.rst-content .toctree-wrapper p.caption:hover .headerlink:after,.rst-content h3:hover .headerlink:after,.rst-content h4:hover .headerlink:after,.rst-content h5:hover .headerlink:after,.rst-content h6:hover .headerlink:after,.rst-content dl dt:hover .headerlink:after,.rst-content p.caption:hover .headerlink:after,.rst-content table>caption:hover .headerlink:after{visibility:visible}.rst-content table>caption .headerlink:after{font-size:12px}.rst-content .centered{text-align:center}.rst-content 
.sidebar{float:right;width:40%;display:block;margin:0 0 24px 24px;padding:24px;background:#f3f6f6;border:solid 1px #e1e4e5}.rst-content .sidebar p,.rst-content .sidebar ul,.rst-content .sidebar dl{font-size:90%}.rst-content .sidebar .last{margin-bottom:0}.rst-content .sidebar .sidebar-title{display:block;font-family:"Roboto Slab","ff-tisa-web-pro","Georgia",Arial,sans-serif;font-weight:bold;background:#e1e4e5;padding:6px 12px;margin:-24px;margin-bottom:24px;font-size:100%}.rst-content .highlighted{background:#F1C40F;display:inline-block;font-weight:bold;padding:0 6px}.rst-content .footnote-reference,.rst-content .citation-reference{vertical-align:baseline;position:relative;top:-0.4em;line-height:0;font-size:90%}.rst-content table.docutils.citation,.rst-content table.docutils.footnote{background:none;border:none;color:gray}.rst-content table.docutils.citation td,.rst-content table.docutils.citation tr,.rst-content table.docutils.footnote td,.rst-content table.docutils.footnote tr{border:none;background-color:transparent !important;white-space:normal}.rst-content table.docutils.citation td.label,.rst-content table.docutils.footnote td.label{padding-left:0;padding-right:0;vertical-align:top}.rst-content table.docutils.citation tt,.rst-content table.docutils.citation code,.rst-content table.docutils.footnote tt,.rst-content table.docutils.footnote code{color:#555}.rst-content .wy-table-responsive.citation,.rst-content .wy-table-responsive.footnote{margin-bottom:0}.rst-content .wy-table-responsive.citation+:not(.citation),.rst-content .wy-table-responsive.footnote+:not(.footnote){margin-top:24px}.rst-content .wy-table-responsive.citation:last-child,.rst-content .wy-table-responsive.footnote:last-child{margin-bottom:24px}.rst-content table.docutils th{border-color:#e1e4e5}.rst-content table.docutils td .last,.rst-content table.docutils td .last :last-child{margin-bottom:0}.rst-content table.field-list{border:none}.rst-content table.field-list td{border:none}.rst-content table.field-list td>strong{display:inline-block}.rst-content table.field-list .field-name{padding-right:10px;text-align:left;white-space:nowrap}.rst-content table.field-list .field-body{text-align:left}.rst-content tt,.rst-content tt,.rst-content code{color:#000;font-family:SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace;padding:2px 5px}.rst-content tt big,.rst-content tt em,.rst-content tt big,.rst-content code big,.rst-content tt em,.rst-content code em{font-size:100% !important;line-height:normal}.rst-content tt.literal,.rst-content tt.literal,.rst-content code.literal{color:#E74C3C}.rst-content tt.xref,a .rst-content tt,.rst-content tt.xref,.rst-content code.xref,a .rst-content tt,a .rst-content code{font-weight:bold;color:#404040}.rst-content pre,.rst-content kbd,.rst-content samp{font-family:SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace}.rst-content a tt,.rst-content a tt,.rst-content a code{color:#2980B9}.rst-content dl{margin-bottom:24px}.rst-content dl dt{font-weight:bold;margin-bottom:12px}.rst-content dl p,.rst-content dl table,.rst-content dl ul,.rst-content dl ol{margin-bottom:12px !important}.rst-content dl dd{margin:0 0 12px 24px;line-height:24px}.rst-content dl:not(.docutils){margin-bottom:24px}.rst-content dl:not(.docutils) dt{display:table;margin:6px 0;font-size:90%;line-height:normal;background:#e7f2fa;color:#2980B9;border-top:solid 3px #6ab0de;padding:6px;position:relative}.rst-content dl:not(.docutils) 
dt:before{color:#6ab0de}.rst-content dl:not(.docutils) dt .headerlink{color:#404040;font-size:100% !important}.rst-content dl:not(.docutils) dl dt{margin-bottom:6px;border:none;border-left:solid 3px #ccc;background:#f0f0f0;color:#555}.rst-content dl:not(.docutils) dl dt .headerlink{color:#404040;font-size:100% !important}.rst-content dl:not(.docutils) dt:first-child{margin-top:0}.rst-content dl:not(.docutils) tt,.rst-content dl:not(.docutils) tt,.rst-content dl:not(.docutils) code{font-weight:bold}.rst-content dl:not(.docutils) tt.descname,.rst-content dl:not(.docutils) tt.descclassname,.rst-content dl:not(.docutils) tt.descname,.rst-content dl:not(.docutils) code.descname,.rst-content dl:not(.docutils) tt.descclassname,.rst-content dl:not(.docutils) code.descclassname{background-color:transparent;border:none;padding:0;font-size:100% !important}.rst-content dl:not(.docutils) tt.descname,.rst-content dl:not(.docutils) tt.descname,.rst-content dl:not(.docutils) code.descname{font-weight:bold}.rst-content dl:not(.docutils) .optional{display:inline-block;padding:0 4px;color:#000;font-weight:bold}.rst-content dl:not(.docutils) .property{display:inline-block;padding-right:8px}.rst-content .viewcode-link,.rst-content .viewcode-back{display:inline-block;color:#27AE60;font-size:80%;padding-left:24px}.rst-content .viewcode-back{display:block;float:right}.rst-content p.rubric{margin-bottom:12px;font-weight:bold}.rst-content tt.download,.rst-content code.download{background:inherit;padding:inherit;font-weight:normal;font-family:inherit;font-size:inherit;color:inherit;border:inherit;white-space:inherit}.rst-content tt.download span:first-child,.rst-content code.download span:first-child{-webkit-font-smoothing:subpixel-antialiased}.rst-content tt.download span:first-child:before,.rst-content code.download span:first-child:before{margin-right:4px}.rst-content .guilabel{border:1px solid #7fbbe3;background:#e7f2fa;font-size:80%;font-weight:700;border-radius:4px;padding:2.4px 6px;margin:auto 2px}.rst-content .versionmodified{font-style:italic}@media screen and (max-width: 480px){.rst-content .sidebar{width:100%}}span[id*='MathJax-Span']{color:#404040}.math{text-align:center}@font-face{font-family:"Lato";src:url("../fonts/Lato/lato-regular.eot");src:url("../fonts/Lato/lato-regular.eot?#iefix") format("embedded-opentype"),url("../fonts/Lato/lato-regular.woff2") format("woff2"),url("../fonts/Lato/lato-regular.woff") format("woff"),url("../fonts/Lato/lato-regular.ttf") format("truetype");font-weight:400;font-style:normal}@font-face{font-family:"Lato";src:url("../fonts/Lato/lato-bold.eot");src:url("../fonts/Lato/lato-bold.eot?#iefix") format("embedded-opentype"),url("../fonts/Lato/lato-bold.woff2") format("woff2"),url("../fonts/Lato/lato-bold.woff") format("woff"),url("../fonts/Lato/lato-bold.ttf") format("truetype");font-weight:700;font-style:normal}@font-face{font-family:"Lato";src:url("../fonts/Lato/lato-bolditalic.eot");src:url("../fonts/Lato/lato-bolditalic.eot?#iefix") format("embedded-opentype"),url("../fonts/Lato/lato-bolditalic.woff2") format("woff2"),url("../fonts/Lato/lato-bolditalic.woff") format("woff"),url("../fonts/Lato/lato-bolditalic.ttf") format("truetype");font-weight:700;font-style:italic}@font-face{font-family:"Lato";src:url("../fonts/Lato/lato-italic.eot");src:url("../fonts/Lato/lato-italic.eot?#iefix") format("embedded-opentype"),url("../fonts/Lato/lato-italic.woff2") format("woff2"),url("../fonts/Lato/lato-italic.woff") format("woff"),url("../fonts/Lato/lato-italic.ttf") 
format("truetype");font-weight:400;font-style:italic}@font-face{font-family:"Roboto Slab";font-style:normal;font-weight:400;src:url("../fonts/RobotoSlab/roboto-slab.eot");src:url("../fonts/RobotoSlab/roboto-slab-v7-regular.eot?#iefix") format("embedded-opentype"),url("../fonts/RobotoSlab/roboto-slab-v7-regular.woff2") format("woff2"),url("../fonts/RobotoSlab/roboto-slab-v7-regular.woff") format("woff"),url("../fonts/RobotoSlab/roboto-slab-v7-regular.ttf") format("truetype")}@font-face{font-family:"Roboto Slab";font-style:normal;font-weight:700;src:url("../fonts/RobotoSlab/roboto-slab-v7-bold.eot");src:url("../fonts/RobotoSlab/roboto-slab-v7-bold.eot?#iefix") format("embedded-opentype"),url("../fonts/RobotoSlab/roboto-slab-v7-bold.woff2") format("woff2"),url("../fonts/RobotoSlab/roboto-slab-v7-bold.woff") format("woff"),url("../fonts/RobotoSlab/roboto-slab-v7-bold.ttf") format("truetype")} + */@font-face{font-family:'FontAwesome';src:url("../fonts/fontawesome-webfont.eot?v=4.7.0");src:url("../fonts/fontawesome-webfont.eot?#iefix&v=4.7.0") format("embedded-opentype"),url("../fonts/fontawesome-webfont.woff2?v=4.7.0") format("woff2"),url("../fonts/fontawesome-webfont.woff?v=4.7.0") format("woff"),url("../fonts/fontawesome-webfont.ttf?v=4.7.0") format("truetype"),url("../fonts/fontawesome-webfont.svg?v=4.7.0#fontawesomeregular") format("svg");font-weight:normal;font-style:normal}.fa,.wy-menu-vertical li span.toctree-expand,.wy-menu-vertical li.on a span.toctree-expand,.wy-menu-vertical li.current>a span.toctree-expand,.rst-content .admonition-title,.rst-content h1 .headerlink,.rst-content h2 .headerlink,.rst-content h3 .headerlink,.rst-content h4 .headerlink,.rst-content h5 .headerlink,.rst-content h6 .headerlink,.rst-content dl dt .headerlink,.rst-content p.caption .headerlink,.rst-content table>caption .headerlink,.rst-content .code-block-caption .headerlink,.rst-content tt.download span:first-child,.rst-content code.download span:first-child,.icon{display:inline-block;font:normal normal normal 14px/1 FontAwesome;font-size:inherit;text-rendering:auto;-webkit-font-smoothing:antialiased;-moz-osx-font-smoothing:grayscale}.fa-lg{font-size:1.3333333333em;line-height:.75em;vertical-align:-15%}.fa-2x{font-size:2em}.fa-3x{font-size:3em}.fa-4x{font-size:4em}.fa-5x{font-size:5em}.fa-fw{width:1.2857142857em;text-align:center}.fa-ul{padding-left:0;margin-left:2.1428571429em;list-style-type:none}.fa-ul>li{position:relative}.fa-li{position:absolute;left:-2.1428571429em;width:2.1428571429em;top:.1428571429em;text-align:center}.fa-li.fa-lg{left:-1.8571428571em}.fa-border{padding:.2em .25em .15em;border:solid 0.08em #eee;border-radius:.1em}.fa-pull-left{float:left}.fa-pull-right{float:right}.fa.fa-pull-left,.wy-menu-vertical li span.fa-pull-left.toctree-expand,.wy-menu-vertical li.on a span.fa-pull-left.toctree-expand,.wy-menu-vertical li.current>a span.fa-pull-left.toctree-expand,.rst-content .fa-pull-left.admonition-title,.rst-content h1 .fa-pull-left.headerlink,.rst-content h2 .fa-pull-left.headerlink,.rst-content h3 .fa-pull-left.headerlink,.rst-content h4 .fa-pull-left.headerlink,.rst-content h5 .fa-pull-left.headerlink,.rst-content h6 .fa-pull-left.headerlink,.rst-content dl dt .fa-pull-left.headerlink,.rst-content p.caption .fa-pull-left.headerlink,.rst-content table>caption .fa-pull-left.headerlink,.rst-content .code-block-caption .fa-pull-left.headerlink,.rst-content tt.download span.fa-pull-left:first-child,.rst-content code.download 
span.fa-pull-left:first-child,.fa-pull-left.icon{margin-right:.3em}.fa.fa-pull-right,.wy-menu-vertical li span.fa-pull-right.toctree-expand,.wy-menu-vertical li.on a span.fa-pull-right.toctree-expand,.wy-menu-vertical li.current>a span.fa-pull-right.toctree-expand,.rst-content .fa-pull-right.admonition-title,.rst-content h1 .fa-pull-right.headerlink,.rst-content h2 .fa-pull-right.headerlink,.rst-content h3 .fa-pull-right.headerlink,.rst-content h4 .fa-pull-right.headerlink,.rst-content h5 .fa-pull-right.headerlink,.rst-content h6 .fa-pull-right.headerlink,.rst-content dl dt .fa-pull-right.headerlink,.rst-content p.caption .fa-pull-right.headerlink,.rst-content table>caption .fa-pull-right.headerlink,.rst-content .code-block-caption .fa-pull-right.headerlink,.rst-content tt.download span.fa-pull-right:first-child,.rst-content code.download span.fa-pull-right:first-child,.fa-pull-right.icon{margin-left:.3em}.pull-right{float:right}.pull-left{float:left}.fa.pull-left,.wy-menu-vertical li span.pull-left.toctree-expand,.wy-menu-vertical li.on a span.pull-left.toctree-expand,.wy-menu-vertical li.current>a span.pull-left.toctree-expand,.rst-content .pull-left.admonition-title,.rst-content h1 .pull-left.headerlink,.rst-content h2 .pull-left.headerlink,.rst-content h3 .pull-left.headerlink,.rst-content h4 .pull-left.headerlink,.rst-content h5 .pull-left.headerlink,.rst-content h6 .pull-left.headerlink,.rst-content dl dt .pull-left.headerlink,.rst-content p.caption .pull-left.headerlink,.rst-content table>caption .pull-left.headerlink,.rst-content .code-block-caption .pull-left.headerlink,.rst-content tt.download span.pull-left:first-child,.rst-content code.download span.pull-left:first-child,.pull-left.icon{margin-right:.3em}.fa.pull-right,.wy-menu-vertical li span.pull-right.toctree-expand,.wy-menu-vertical li.on a span.pull-right.toctree-expand,.wy-menu-vertical li.current>a span.pull-right.toctree-expand,.rst-content .pull-right.admonition-title,.rst-content h1 .pull-right.headerlink,.rst-content h2 .pull-right.headerlink,.rst-content h3 .pull-right.headerlink,.rst-content h4 .pull-right.headerlink,.rst-content h5 .pull-right.headerlink,.rst-content h6 .pull-right.headerlink,.rst-content dl dt .pull-right.headerlink,.rst-content p.caption .pull-right.headerlink,.rst-content table>caption .pull-right.headerlink,.rst-content .code-block-caption .pull-right.headerlink,.rst-content tt.download span.pull-right:first-child,.rst-content code.download span.pull-right:first-child,.pull-right.icon{margin-left:.3em}.fa-spin{-webkit-animation:fa-spin 2s infinite linear;animation:fa-spin 2s infinite linear}.fa-pulse{-webkit-animation:fa-spin 1s infinite steps(8);animation:fa-spin 1s infinite steps(8)}@-webkit-keyframes fa-spin{0%{-webkit-transform:rotate(0deg);transform:rotate(0deg)}100%{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes 
fa-spin{0%{-webkit-transform:rotate(0deg);transform:rotate(0deg)}100%{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.fa-rotate-90{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=1)";-webkit-transform:rotate(90deg);-ms-transform:rotate(90deg);transform:rotate(90deg)}.fa-rotate-180{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=2)";-webkit-transform:rotate(180deg);-ms-transform:rotate(180deg);transform:rotate(180deg)}.fa-rotate-270{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=3)";-webkit-transform:rotate(270deg);-ms-transform:rotate(270deg);transform:rotate(270deg)}.fa-flip-horizontal{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=0, mirror=1)";-webkit-transform:scale(-1, 1);-ms-transform:scale(-1, 1);transform:scale(-1, 1)}.fa-flip-vertical{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=2, mirror=1)";-webkit-transform:scale(1, -1);-ms-transform:scale(1, -1);transform:scale(1, -1)}:root .fa-rotate-90,:root .fa-rotate-180,:root .fa-rotate-270,:root .fa-flip-horizontal,:root .fa-flip-vertical{filter:none}.fa-stack{position:relative;display:inline-block;width:2em;height:2em;line-height:2em;vertical-align:middle}.fa-stack-1x,.fa-stack-2x{position:absolute;left:0;width:100%;text-align:center}.fa-stack-1x{line-height:inherit}.fa-stack-2x{font-size:2em}.fa-inverse{color:#fff}.fa-glass:before{content:""}.fa-music:before{content:""}.fa-search:before,.icon-search:before{content:""}.fa-envelope-o:before{content:""}.fa-heart:before{content:""}.fa-star:before{content:""}.fa-star-o:before{content:""}.fa-user:before{content:""}.fa-film:before{content:""}.fa-th-large:before{content:""}.fa-th:before{content:""}.fa-th-list:before{content:""}.fa-check:before{content:""}.fa-remove:before,.fa-close:before,.fa-times:before{content:""}.fa-search-plus:before{content:""}.fa-search-minus:before{content:""}.fa-power-off:before{content:""}.fa-signal:before{content:""}.fa-gear:before,.fa-cog:before{content:""}.fa-trash-o:before{content:""}.fa-home:before,.icon-home:before{content:""}.fa-file-o:before{content:""}.fa-clock-o:before{content:""}.fa-road:before{content:""}.fa-download:before,.rst-content tt.download span:first-child:before,.rst-content code.download 
span:first-child:before{content:""}.fa-arrow-circle-o-down:before{content:""}.fa-arrow-circle-o-up:before{content:""}.fa-inbox:before{content:""}.fa-play-circle-o:before{content:""}.fa-rotate-right:before,.fa-repeat:before{content:""}.fa-refresh:before{content:""}.fa-list-alt:before{content:""}.fa-lock:before{content:""}.fa-flag:before{content:""}.fa-headphones:before{content:""}.fa-volume-off:before{content:""}.fa-volume-down:before{content:""}.fa-volume-up:before{content:""}.fa-qrcode:before{content:""}.fa-barcode:before{content:""}.fa-tag:before{content:""}.fa-tags:before{content:""}.fa-book:before,.icon-book:before{content:""}.fa-bookmark:before{content:""}.fa-print:before{content:""}.fa-camera:before{content:""}.fa-font:before{content:""}.fa-bold:before{content:""}.fa-italic:before{content:""}.fa-text-height:before{content:""}.fa-text-width:before{content:""}.fa-align-left:before{content:""}.fa-align-center:before{content:""}.fa-align-right:before{content:""}.fa-align-justify:before{content:""}.fa-list:before{content:""}.fa-dedent:before,.fa-outdent:before{content:""}.fa-indent:before{content:""}.fa-video-camera:before{content:""}.fa-photo:before,.fa-image:before,.fa-picture-o:before{content:""}.fa-pencil:before{content:""}.fa-map-marker:before{content:""}.fa-adjust:before{content:""}.fa-tint:before{content:""}.fa-edit:before,.fa-pencil-square-o:before{content:""}.fa-share-square-o:before{content:""}.fa-check-square-o:before{content:""}.fa-arrows:before{content:""}.fa-step-backward:before{content:""}.fa-fast-backward:before{content:""}.fa-backward:before{content:""}.fa-play:before{content:""}.fa-pause:before{content:""}.fa-stop:before{content:""}.fa-forward:before{content:""}.fa-fast-forward:before{content:""}.fa-step-forward:before{content:""}.fa-eject:before{content:""}.fa-chevron-left:before{content:""}.fa-chevron-right:before{content:""}.fa-plus-circle:before{content:""}.fa-minus-circle:before{content:""}.fa-times-circle:before,.wy-inline-validate.wy-inline-validate-danger .wy-input-context:before{content:""}.fa-check-circle:before,.wy-inline-validate.wy-inline-validate-success .wy-input-context:before{content:""}.fa-question-circle:before{content:""}.fa-info-circle:before{content:""}.fa-crosshairs:before{content:""}.fa-times-circle-o:before{content:""}.fa-check-circle-o:before{content:""}.fa-ban:before{content:""}.fa-arrow-left:before{content:""}.fa-arrow-right:before{content:""}.fa-arrow-up:before{content:""}.fa-arrow-down:before{content:""}.fa-mail-forward:before,.fa-share:before{content:""}.fa-expand:before{content:""}.fa-compress:before{content:""}.fa-plus:before{content:""}.fa-minus:before{content:""}.fa-asterisk:before{content:""}.fa-exclamation-circle:before,.wy-inline-validate.wy-inline-validate-warning .wy-input-context:before,.wy-inline-validate.wy-inline-validate-info .wy-input-context:before,.rst-content 
.admonition-title:before{content:""}.fa-gift:before{content:""}.fa-leaf:before{content:""}.fa-fire:before,.icon-fire:before{content:""}.fa-eye:before{content:""}.fa-eye-slash:before{content:""}.fa-warning:before,.fa-exclamation-triangle:before{content:""}.fa-plane:before{content:""}.fa-calendar:before{content:""}.fa-random:before{content:""}.fa-comment:before{content:""}.fa-magnet:before{content:""}.fa-chevron-up:before{content:""}.fa-chevron-down:before{content:""}.fa-retweet:before{content:""}.fa-shopping-cart:before{content:""}.fa-folder:before{content:""}.fa-folder-open:before{content:""}.fa-arrows-v:before{content:""}.fa-arrows-h:before{content:""}.fa-bar-chart-o:before,.fa-bar-chart:before{content:""}.fa-twitter-square:before{content:""}.fa-facebook-square:before{content:""}.fa-camera-retro:before{content:""}.fa-key:before{content:""}.fa-gears:before,.fa-cogs:before{content:""}.fa-comments:before{content:""}.fa-thumbs-o-up:before{content:""}.fa-thumbs-o-down:before{content:""}.fa-star-half:before{content:""}.fa-heart-o:before{content:""}.fa-sign-out:before{content:""}.fa-linkedin-square:before{content:""}.fa-thumb-tack:before{content:""}.fa-external-link:before{content:""}.fa-sign-in:before{content:""}.fa-trophy:before{content:""}.fa-github-square:before{content:""}.fa-upload:before{content:""}.fa-lemon-o:before{content:""}.fa-phone:before{content:""}.fa-square-o:before{content:""}.fa-bookmark-o:before{content:""}.fa-phone-square:before{content:""}.fa-twitter:before{content:""}.fa-facebook-f:before,.fa-facebook:before{content:""}.fa-github:before,.icon-github:before{content:""}.fa-unlock:before{content:""}.fa-credit-card:before{content:""}.fa-feed:before,.fa-rss:before{content:""}.fa-hdd-o:before{content:""}.fa-bullhorn:before{content:""}.fa-bell:before{content:""}.fa-certificate:before{content:""}.fa-hand-o-right:before{content:""}.fa-hand-o-left:before{content:""}.fa-hand-o-up:before{content:""}.fa-hand-o-down:before{content:""}.fa-arrow-circle-left:before,.icon-circle-arrow-left:before{content:""}.fa-arrow-circle-right:before,.icon-circle-arrow-right:before{content:""}.fa-arrow-circle-up:before{content:""}.fa-arrow-circle-down:before{content:""}.fa-globe:before{content:""}.fa-wrench:before{content:""}.fa-tasks:before{content:""}.fa-filter:before{content:""}.fa-briefcase:before{content:""}.fa-arrows-alt:before{content:""}.fa-group:before,.fa-users:before{content:""}.fa-chain:before,.fa-link:before,.icon-link:before{content:""}.fa-cloud:before{content:""}.fa-flask:before{content:""}.fa-cut:before,.fa-scissors:before{content:""}.fa-copy:before,.fa-files-o:before{content:""}.fa-paperclip:before{content:""}.fa-save:before,.fa-floppy-o:before{content:""}.fa-square:before{content:""}.fa-navicon:before,.fa-reorder:before,.fa-bars:before{content:""}.fa-list-ul:before{content:""}.fa-list-ol:before{content:""}.fa-strikethrough:before{content:""}.fa-underline:before{content:""}.fa-table:before{content:""}.fa-magic:before{content:""}.fa-truck:before{content:""}.fa-pinterest:before{content:""}.fa-pinterest-square:before{content:""}.fa-google-plus-square:before{content:""}.fa-google-plus:before{content:""}.fa-money:before{content:""}.fa-caret-down:before,.wy-dropdown 
.caret:before,.icon-caret-down:before{content:""}.fa-caret-up:before{content:""}.fa-caret-left:before{content:""}.fa-caret-right:before{content:""}.fa-columns:before{content:""}.fa-unsorted:before,.fa-sort:before{content:""}.fa-sort-down:before,.fa-sort-desc:before{content:""}.fa-sort-up:before,.fa-sort-asc:before{content:""}.fa-envelope:before{content:""}.fa-linkedin:before{content:""}.fa-rotate-left:before,.fa-undo:before{content:""}.fa-legal:before,.fa-gavel:before{content:""}.fa-dashboard:before,.fa-tachometer:before{content:""}.fa-comment-o:before{content:""}.fa-comments-o:before{content:""}.fa-flash:before,.fa-bolt:before{content:""}.fa-sitemap:before{content:""}.fa-umbrella:before{content:""}.fa-paste:before,.fa-clipboard:before{content:""}.fa-lightbulb-o:before{content:""}.fa-exchange:before{content:""}.fa-cloud-download:before{content:""}.fa-cloud-upload:before{content:""}.fa-user-md:before{content:""}.fa-stethoscope:before{content:""}.fa-suitcase:before{content:""}.fa-bell-o:before{content:""}.fa-coffee:before{content:""}.fa-cutlery:before{content:""}.fa-file-text-o:before{content:""}.fa-building-o:before{content:""}.fa-hospital-o:before{content:""}.fa-ambulance:before{content:""}.fa-medkit:before{content:""}.fa-fighter-jet:before{content:""}.fa-beer:before{content:""}.fa-h-square:before{content:""}.fa-plus-square:before{content:""}.fa-angle-double-left:before{content:""}.fa-angle-double-right:before{content:""}.fa-angle-double-up:before{content:""}.fa-angle-double-down:before{content:""}.fa-angle-left:before{content:""}.fa-angle-right:before{content:""}.fa-angle-up:before{content:""}.fa-angle-down:before{content:""}.fa-desktop:before{content:""}.fa-laptop:before{content:""}.fa-tablet:before{content:""}.fa-mobile-phone:before,.fa-mobile:before{content:""}.fa-circle-o:before{content:""}.fa-quote-left:before{content:""}.fa-quote-right:before{content:""}.fa-spinner:before{content:""}.fa-circle:before{content:""}.fa-mail-reply:before,.fa-reply:before{content:""}.fa-github-alt:before{content:""}.fa-folder-o:before{content:""}.fa-folder-open-o:before{content:""}.fa-smile-o:before{content:""}.fa-frown-o:before{content:""}.fa-meh-o:before{content:""}.fa-gamepad:before{content:""}.fa-keyboard-o:before{content:""}.fa-flag-o:before{content:""}.fa-flag-checkered:before{content:""}.fa-terminal:before{content:""}.fa-code:before{content:""}.fa-mail-reply-all:before,.fa-reply-all:before{content:""}.fa-star-half-empty:before,.fa-star-half-full:before,.fa-star-half-o:before{content:""}.fa-location-arrow:before{content:""}.fa-crop:before{content:""}.fa-code-fork:before{content:""}.fa-unlink:before,.fa-chain-broken:before{content:""}.fa-question:before{content:""}.fa-info:before{content:""}.fa-exclamation:before{content:""}.fa-superscript:before{content:""}.fa-subscript:before{content:""}.fa-eraser:before{content:""}.fa-puzzle-piece:before{content:""}.fa-microphone:before{content:""}.fa-microphone-slash:before{content:""}.fa-shield:before{content:""}.fa-calendar-o:before{content:""}.fa-fire-extinguisher:before{content:""}.fa-rocket:before{content:""}.fa-maxcdn:before{content:""}.fa-chevron-circle-left:before{content:""}.fa-chevron-circle-right:before{content:""}.fa-chevron-circle-up:before{content:""}.fa-chevron-circle-down:before{content:""}.fa-html5:before{content:""}.fa-css3:before{content:""}.fa-anchor:before{content:""}.fa-unlock-alt:before{content:""}.fa-bullseye:before{content:""}.fa-ellipsis-h:bef
ore{content:""}.fa-ellipsis-v:before{content:""}.fa-rss-square:before{content:""}.fa-play-circle:before{content:""}.fa-ticket:before{content:""}.fa-minus-square:before{content:""}.fa-minus-square-o:before,.wy-menu-vertical li.on a span.toctree-expand:before,.wy-menu-vertical li.current>a span.toctree-expand:before{content:""}.fa-level-up:before{content:""}.fa-level-down:before{content:""}.fa-check-square:before{content:""}.fa-pencil-square:before{content:""}.fa-external-link-square:before{content:""}.fa-share-square:before{content:""}.fa-compass:before{content:""}.fa-toggle-down:before,.fa-caret-square-o-down:before{content:""}.fa-toggle-up:before,.fa-caret-square-o-up:before{content:""}.fa-toggle-right:before,.fa-caret-square-o-right:before{content:""}.fa-euro:before,.fa-eur:before{content:""}.fa-gbp:before{content:""}.fa-dollar:before,.fa-usd:before{content:""}.fa-rupee:before,.fa-inr:before{content:""}.fa-cny:before,.fa-rmb:before,.fa-yen:before,.fa-jpy:before{content:""}.fa-ruble:before,.fa-rouble:before,.fa-rub:before{content:""}.fa-won:before,.fa-krw:before{content:""}.fa-bitcoin:before,.fa-btc:before{content:""}.fa-file:before{content:""}.fa-file-text:before{content:""}.fa-sort-alpha-asc:before{content:""}.fa-sort-alpha-desc:before{content:""}.fa-sort-amount-asc:before{content:""}.fa-sort-amount-desc:before{content:""}.fa-sort-numeric-asc:before{content:""}.fa-sort-numeric-desc:before{content:""}.fa-thumbs-up:before{content:""}.fa-thumbs-down:before{content:""}.fa-youtube-square:before{content:""}.fa-youtube:before{content:""}.fa-xing:before{content:""}.fa-xing-square:before{content:""}.fa-youtube-play:before{content:""}.fa-dropbox:before{content:""}.fa-stack-overflow:before{content:""}.fa-instagram:before{content:""}.fa-flickr:before{content:""}.fa-adn:before{content:""}.fa-bitbucket:before,.icon-bitbucket:before{content:""}.fa-bitbucket-square:before{content:""}.fa-tumblr:before{content:""}.fa-tumblr-square:before{content:""}.fa-long-arrow-down:before{content:""}.fa-long-arrow-up:before{content:""}.fa-long-arrow-left:before{content:""}.fa-long-arrow-right:before{content:""}.fa-apple:before{content:""}.fa-windows:before{content:""}.fa-android:before{content:""}.fa-linux:before{content:""}.fa-dribbble:before{content:""}.fa-skype:before{content:""}.fa-foursquare:before{content:""}.fa-trello:before{content:""}.fa-female:before{content:""}.fa-male:before{content:""}.fa-gittip:before,.fa-gratipay:before{content:""}.fa-sun-o:before{content:""}.fa-moon-o:before{content:""}.fa-archive:before{content:""}.fa-bug:before{content:""}.fa-vk:before{content:""}.fa-weibo:before{content:""}.fa-renren:before{content:""}.fa-pagelines:before{content:""}.fa-stack-exchange:before{content:""}.fa-arrow-circle-o-right:before{content:""}.fa-arrow-circle-o-left:before{content:""}.fa-toggle-left:before,.fa-caret-square-o-left:before{content:""}.fa-dot-circle-o:before{content:""}.fa-wheelchair:before{content:""}.fa-vimeo-square:before{content:""}.fa-turkish-lira:before,.fa-try:before{content:""}.fa-plus-square-o:before,.wy-menu-vertical li 
span.toctree-expand:before{content:""}.fa-space-shuttle:before{content:""}.fa-slack:before{content:""}.fa-envelope-square:before{content:""}.fa-wordpress:before{content:""}.fa-openid:before{content:""}.fa-institution:before,.fa-bank:before,.fa-university:before{content:""}.fa-mortar-board:before,.fa-graduation-cap:before{content:""}.fa-yahoo:before{content:""}.fa-google:before{content:""}.fa-reddit:before{content:""}.fa-reddit-square:before{content:""}.fa-stumbleupon-circle:before{content:""}.fa-stumbleupon:before{content:""}.fa-delicious:before{content:""}.fa-digg:before{content:""}.fa-pied-piper-pp:before{content:""}.fa-pied-piper-alt:before{content:""}.fa-drupal:before{content:""}.fa-joomla:before{content:""}.fa-language:before{content:""}.fa-fax:before{content:""}.fa-building:before{content:""}.fa-child:before{content:""}.fa-paw:before{content:""}.fa-spoon:before{content:""}.fa-cube:before{content:""}.fa-cubes:before{content:""}.fa-behance:before{content:""}.fa-behance-square:before{content:""}.fa-steam:before{content:""}.fa-steam-square:before{content:""}.fa-recycle:before{content:""}.fa-automobile:before,.fa-car:before{content:""}.fa-cab:before,.fa-taxi:before{content:""}.fa-tree:before{content:""}.fa-spotify:before{content:""}.fa-deviantart:before{content:""}.fa-soundcloud:before{content:""}.fa-database:before{content:""}.fa-file-pdf-o:before{content:""}.fa-file-word-o:before{content:""}.fa-file-excel-o:before{content:""}.fa-file-powerpoint-o:before{content:""}.fa-file-photo-o:before,.fa-file-picture-o:before,.fa-file-image-o:before{content:""}.fa-file-zip-o:before,.fa-file-archive-o:before{content:""}.fa-file-sound-o:before,.fa-file-audio-o:before{content:""}.fa-file-movie-o:before,.fa-file-video-o:before{content:""}.fa-file-code-o:before{content:""}.fa-vine:before{content:""}.fa-codepen:before{content:""}.fa-jsfiddle:before{content:""}.fa-life-bouy:before,.fa-life-buoy:before,.fa-life-saver:before,.fa-support:before,.fa-life-ring:before{content:""}.fa-circle-o-notch:before{content:""}.fa-ra:before,.fa-resistance:before,.fa-rebel:before{content:""}.fa-ge:before,.fa-empire:before{content:""}.fa-git-square:before{content:""}.fa-git:before{content:""}.fa-y-combinator-square:before,.fa-yc-square:before,.fa-hacker-news:before{content:""}.fa-tencent-weibo:before{content:""}.fa-qq:before{content:""}.fa-wechat:before,.fa-weixin:before{content:""}.fa-send:before,.fa-paper-plane:before{content:""}.fa-send-o:before,.fa-paper-plane-o:before{content:""}.fa-history:before{content:""}.fa-circle-thin:before{content:""}.fa-header:before{content:""}.fa-paragraph:before{content:""}.fa-sliders:before{content:""}.fa-share-alt:before{content:""}.fa-share-alt-square:before{content:""}.fa-bomb:before{content:""}.fa-soccer-ball-o:before,.fa-futbol-o:before{content:""}.fa-tty:before{content:""}.fa-binoculars:before{content:""}.fa-plug:before{content:""}.fa-slideshare:before{content:""}.fa-twitch:before{content:""}.fa-yelp:before{content:""}.fa-newspaper-o:before{content:""}.fa-wifi:before{content:""}.fa-calculator:before{content:""}.fa-paypal:before{content:""}.fa-google-wallet:before{content:""}.fa-cc-visa:before{content:""}.fa-cc-mastercard:before{content:""}.fa-cc-discover:before{content:""}.fa-cc-amex:before{content:""}.fa-cc-paypal:before{content:""}.fa-cc-stripe:before{content:""}.fa-bell-slash:before{content:""}.fa-bell-slash-o:before{content:""}.fa-trash:before{content:""}.fa-copyright:before{content:""}.fa-
[docs/build/html/_static/css/theme.css (continued): minified sphinx_rtd_theme stylesheet, a vendored build artifact. This portion spans the Font Awesome icon classes (whose escaped glyph codepoints, e.g. content:"\f1fb", were stripped to empty content:"" in this copy) followed by the theme's screen-reader, admonition/alert, button, dropdown, form, table, typography, breadcrumb, vertical-menu, navigation, and version-switcher rules. Third-party generated CSS omitted; the canonical file ships with the sphinx_rtd_theme package.]
.rst-current-version:before,.rst-versions .rst-current-version:after{display:table;content:""}.rst-versions .rst-current-version:after{clear:both}.rst-versions .rst-current-version .fa,.rst-versions .rst-current-version .wy-menu-vertical li span.toctree-expand,.wy-menu-vertical li .rst-versions .rst-current-version span.toctree-expand,.rst-versions .rst-current-version .rst-content .admonition-title,.rst-content .rst-versions .rst-current-version .admonition-title,.rst-versions .rst-current-version .rst-content h1 .headerlink,.rst-content h1 .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content h2 .headerlink,.rst-content h2 .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content h3 .headerlink,.rst-content h3 .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content h4 .headerlink,.rst-content h4 .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content h5 .headerlink,.rst-content h5 .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content h6 .headerlink,.rst-content h6 .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content dl dt .headerlink,.rst-content dl dt .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content p.caption .headerlink,.rst-content p.caption .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content table>caption .headerlink,.rst-content table>caption .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content .code-block-caption .headerlink,.rst-content .code-block-caption .rst-versions .rst-current-version .headerlink,.rst-versions .rst-current-version .rst-content tt.download span:first-child,.rst-content tt.download .rst-versions .rst-current-version span:first-child,.rst-versions .rst-current-version .rst-content code.download span:first-child,.rst-content code.download .rst-versions .rst-current-version span:first-child,.rst-versions .rst-current-version .icon{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#E74C3C;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#F1C40F;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:gray;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:solid 1px #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge 
.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width: 768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}}.rst-content img{max-width:100%;height:auto}.rst-content div.figure{margin-bottom:24px}.rst-content div.figure p.caption{font-style:italic}.rst-content div.figure p:last-child.caption{margin-bottom:0px}.rst-content div.figure.align-center{text-align:center}.rst-content .section>img,.rst-content .section>a>img{margin-bottom:24px}.rst-content abbr[title]{text-decoration:none}.rst-content.style-external-links a.reference.external:after{font-family:FontAwesome;content:"";color:#b3b3b3;vertical-align:super;font-size:60%;margin:0 .2em}.rst-content blockquote{margin-left:24px;line-height:24px;margin-bottom:24px}.rst-content pre.literal-block{white-space:pre;margin:0;padding:12px 12px;font-family:SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace;display:block;overflow:auto}.rst-content pre.literal-block,.rst-content div[class^='highlight']{border:1px solid #e1e4e5;overflow-x:auto;margin:1px 0 24px 0}.rst-content pre.literal-block div[class^='highlight'],.rst-content div[class^='highlight'] div[class^='highlight']{padding:0px;border:none;margin:0}.rst-content div[class^='highlight'] td.code{width:100%}.rst-content .linenodiv pre{border-right:solid 1px #e6e9ea;margin:0;padding:12px 12px;font-family:SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace;user-select:none;pointer-events:none}.rst-content div[class^='highlight'] pre{white-space:pre;margin:0;padding:12px 12px;display:block;overflow:auto}.rst-content div[class^='highlight'] pre .hll{display:block;margin:0 -12px;padding:0 12px}.rst-content pre.literal-block,.rst-content div[class^='highlight'] pre,.rst-content .linenodiv pre{font-family:SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace;font-size:12px;line-height:1.4}.rst-content .code-block-caption{font-style:italic;font-size:85%;line-height:1;padding:1em 0;text-align:center}@media print{.rst-content .codeblock,.rst-content div[class^='highlight'],.rst-content div[class^='highlight'] pre{white-space:pre-wrap}}.rst-content .note .last,.rst-content .attention .last,.rst-content .caution .last,.rst-content .danger .last,.rst-content .error .last,.rst-content .hint .last,.rst-content .important .last,.rst-content .tip .last,.rst-content .warning .last,.rst-content .seealso .last,.rst-content .admonition-todo .last,.rst-content .admonition .last{margin-bottom:0}.rst-content .admonition-title:before{margin-right:4px}.rst-content .admonition table{border-color:rgba(0,0,0,0.1)}.rst-content .admonition table td,.rst-content .admonition table th{background:transparent !important;border-color:rgba(0,0,0,0.1) !important}.rst-content .section ol.loweralpha,.rst-content .section ol.loweralpha li{list-style:lower-alpha}.rst-content .section ol.upperalpha,.rst-content .section ol.upperalpha li{list-style:upper-alpha}.rst-content .section ol p,.rst-content .section ul p{margin-bottom:12px}.rst-content .section ol p:last-child,.rst-content .section ul p:last-child{margin-bottom:24px}.rst-content .line-block{margin-left:0px;margin-bottom:24px;line-height:24px}.rst-content .line-block .line-block{margin-left:24px;margin-bottom:0px}.rst-content .topic-title{font-weight:bold;margin-bottom:12px}.rst-content .toc-backref{color:#404040}.rst-content .align-right{float:right;margin:0px 0px 24px 
24px}.rst-content .align-left{float:left;margin:0px 24px 24px 0px}.rst-content .align-center{margin:auto}.rst-content .align-center:not(table){display:block}.rst-content h1 .headerlink,.rst-content h2 .headerlink,.rst-content .toctree-wrapper p.caption .headerlink,.rst-content h3 .headerlink,.rst-content h4 .headerlink,.rst-content h5 .headerlink,.rst-content h6 .headerlink,.rst-content dl dt .headerlink,.rst-content p.caption .headerlink,.rst-content table>caption .headerlink,.rst-content .code-block-caption .headerlink{visibility:hidden;font-size:14px}.rst-content h1 .headerlink:after,.rst-content h2 .headerlink:after,.rst-content .toctree-wrapper p.caption .headerlink:after,.rst-content h3 .headerlink:after,.rst-content h4 .headerlink:after,.rst-content h5 .headerlink:after,.rst-content h6 .headerlink:after,.rst-content dl dt .headerlink:after,.rst-content p.caption .headerlink:after,.rst-content table>caption .headerlink:after,.rst-content .code-block-caption .headerlink:after{content:"";font-family:FontAwesome}.rst-content h1:hover .headerlink:after,.rst-content h2:hover .headerlink:after,.rst-content .toctree-wrapper p.caption:hover .headerlink:after,.rst-content h3:hover .headerlink:after,.rst-content h4:hover .headerlink:after,.rst-content h5:hover .headerlink:after,.rst-content h6:hover .headerlink:after,.rst-content dl dt:hover .headerlink:after,.rst-content p.caption:hover .headerlink:after,.rst-content table>caption:hover .headerlink:after,.rst-content .code-block-caption:hover .headerlink:after{visibility:visible}.rst-content table>caption .headerlink:after{font-size:12px}.rst-content .centered{text-align:center}.rst-content .sidebar{float:right;width:40%;display:block;margin:0 0 24px 24px;padding:24px;background:#f3f6f6;border:solid 1px #e1e4e5}.rst-content .sidebar p,.rst-content .sidebar ul,.rst-content .sidebar dl{font-size:90%}.rst-content .sidebar .last{margin-bottom:0}.rst-content .sidebar .sidebar-title{display:block;font-family:"Roboto Slab","ff-tisa-web-pro","Georgia",Arial,sans-serif;font-weight:bold;background:#e1e4e5;padding:6px 12px;margin:-24px;margin-bottom:24px;font-size:100%}.rst-content .highlighted{background:#F1C40F;display:inline-block;font-weight:bold;padding:0 6px}.rst-content .footnote-reference,.rst-content .citation-reference{vertical-align:baseline;position:relative;top:-0.4em;line-height:0;font-size:90%}.rst-content table.docutils.citation,.rst-content table.docutils.footnote{background:none;border:none;color:gray}.rst-content table.docutils.citation td,.rst-content table.docutils.citation tr,.rst-content table.docutils.footnote td,.rst-content table.docutils.footnote tr{border:none;background-color:transparent !important;white-space:normal}.rst-content table.docutils.citation td.label,.rst-content table.docutils.footnote td.label{padding-left:0;padding-right:0;vertical-align:top}.rst-content table.docutils.citation tt,.rst-content table.docutils.citation code,.rst-content table.docutils.footnote tt,.rst-content table.docutils.footnote code{color:#555}.rst-content .wy-table-responsive.citation,.rst-content .wy-table-responsive.footnote{margin-bottom:0}.rst-content .wy-table-responsive.citation+:not(.citation),.rst-content .wy-table-responsive.footnote+:not(.footnote){margin-top:24px}.rst-content .wy-table-responsive.citation:last-child,.rst-content .wy-table-responsive.footnote:last-child{margin-bottom:24px}.rst-content table.docutils th{border-color:#e1e4e5}.rst-content table.docutils td .last,.rst-content table.docutils td .last 
:last-child{margin-bottom:0}.rst-content table.field-list{border:none}.rst-content table.field-list td{border:none}.rst-content table.field-list td p{font-size:inherit;line-height:inherit}.rst-content table.field-list td>strong{display:inline-block}.rst-content table.field-list .field-name{padding-right:10px;text-align:left;white-space:nowrap}.rst-content table.field-list .field-body{text-align:left}.rst-content tt,.rst-content tt,.rst-content code{color:#000;font-family:SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace;padding:2px 5px}.rst-content tt big,.rst-content tt em,.rst-content tt big,.rst-content code big,.rst-content tt em,.rst-content code em{font-size:100% !important;line-height:normal}.rst-content tt.literal,.rst-content tt.literal,.rst-content code.literal{color:#E74C3C}.rst-content tt.xref,a .rst-content tt,.rst-content tt.xref,.rst-content code.xref,a .rst-content tt,a .rst-content code{font-weight:bold;color:#404040}.rst-content pre,.rst-content kbd,.rst-content samp{font-family:SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace}.rst-content a tt,.rst-content a tt,.rst-content a code{color:#2980B9}.rst-content dl{margin-bottom:24px}.rst-content dl dt{font-weight:bold;margin-bottom:12px}.rst-content dl p,.rst-content dl table,.rst-content dl ul,.rst-content dl ol{margin-bottom:12px !important}.rst-content dl dd{margin:0 0 12px 24px;line-height:24px}.rst-content dl:not(.docutils){margin-bottom:24px}.rst-content dl:not(.docutils) dt{display:table;margin:6px 0;font-size:90%;line-height:normal;background:#e7f2fa;color:#2980B9;border-top:solid 3px #6ab0de;padding:6px;position:relative}.rst-content dl:not(.docutils) dt:before{color:#6ab0de}.rst-content dl:not(.docutils) dt .headerlink{color:#404040;font-size:100% !important}.rst-content dl:not(.docutils) dl dt{margin-bottom:6px;border:none;border-left:solid 3px #ccc;background:#f0f0f0;color:#555}.rst-content dl:not(.docutils) dl dt .headerlink{color:#404040;font-size:100% !important}.rst-content dl:not(.docutils) dt:first-child{margin-top:0}.rst-content dl:not(.docutils) tt,.rst-content dl:not(.docutils) tt,.rst-content dl:not(.docutils) code{font-weight:bold}.rst-content dl:not(.docutils) tt.descname,.rst-content dl:not(.docutils) tt.descclassname,.rst-content dl:not(.docutils) tt.descname,.rst-content dl:not(.docutils) code.descname,.rst-content dl:not(.docutils) tt.descclassname,.rst-content dl:not(.docutils) code.descclassname{background-color:transparent;border:none;padding:0;font-size:100% !important}.rst-content dl:not(.docutils) tt.descname,.rst-content dl:not(.docutils) tt.descname,.rst-content dl:not(.docutils) code.descname{font-weight:bold}.rst-content dl:not(.docutils) .optional{display:inline-block;padding:0 4px;color:#000;font-weight:bold}.rst-content dl:not(.docutils) .property{display:inline-block;padding-right:8px}.rst-content .viewcode-link,.rst-content .viewcode-back{display:inline-block;color:#27AE60;font-size:80%;padding-left:24px}.rst-content .viewcode-back{display:block;float:right}.rst-content p.rubric{margin-bottom:12px;font-weight:bold}.rst-content tt.download,.rst-content code.download{background:inherit;padding:inherit;font-weight:normal;font-family:inherit;font-size:inherit;color:inherit;border:inherit;white-space:inherit}.rst-content tt.download span:first-child,.rst-content code.download span:first-child{-webkit-font-smoothing:subpixel-antialiased}.rst-content tt.download span:first-child:before,.rst-content code.download 
span:first-child:before{margin-right:4px}.rst-content .guilabel{border:1px solid #7fbbe3;background:#e7f2fa;font-size:80%;font-weight:700;border-radius:4px;padding:2.4px 6px;margin:auto 2px}.rst-content .versionmodified{font-style:italic}@media screen and (max-width: 480px){.rst-content .sidebar{width:100%}}span[id*='MathJax-Span']{color:#404040}.math{text-align:center}@font-face{font-family:"Lato";src:url("../fonts/Lato/lato-regular.eot");src:url("../fonts/Lato/lato-regular.eot?#iefix") format("embedded-opentype"),url("../fonts/Lato/lato-regular.woff2") format("woff2"),url("../fonts/Lato/lato-regular.woff") format("woff"),url("../fonts/Lato/lato-regular.ttf") format("truetype");font-weight:400;font-style:normal}@font-face{font-family:"Lato";src:url("../fonts/Lato/lato-bold.eot");src:url("../fonts/Lato/lato-bold.eot?#iefix") format("embedded-opentype"),url("../fonts/Lato/lato-bold.woff2") format("woff2"),url("../fonts/Lato/lato-bold.woff") format("woff"),url("../fonts/Lato/lato-bold.ttf") format("truetype");font-weight:700;font-style:normal}@font-face{font-family:"Lato";src:url("../fonts/Lato/lato-bolditalic.eot");src:url("../fonts/Lato/lato-bolditalic.eot?#iefix") format("embedded-opentype"),url("../fonts/Lato/lato-bolditalic.woff2") format("woff2"),url("../fonts/Lato/lato-bolditalic.woff") format("woff"),url("../fonts/Lato/lato-bolditalic.ttf") format("truetype");font-weight:700;font-style:italic}@font-face{font-family:"Lato";src:url("../fonts/Lato/lato-italic.eot");src:url("../fonts/Lato/lato-italic.eot?#iefix") format("embedded-opentype"),url("../fonts/Lato/lato-italic.woff2") format("woff2"),url("../fonts/Lato/lato-italic.woff") format("woff"),url("../fonts/Lato/lato-italic.ttf") format("truetype");font-weight:400;font-style:italic}@font-face{font-family:"Roboto Slab";font-style:normal;font-weight:400;src:url("../fonts/RobotoSlab/roboto-slab.eot");src:url("../fonts/RobotoSlab/roboto-slab-v7-regular.eot?#iefix") format("embedded-opentype"),url("../fonts/RobotoSlab/roboto-slab-v7-regular.woff2") format("woff2"),url("../fonts/RobotoSlab/roboto-slab-v7-regular.woff") format("woff"),url("../fonts/RobotoSlab/roboto-slab-v7-regular.ttf") format("truetype")}@font-face{font-family:"Roboto Slab";font-style:normal;font-weight:700;src:url("../fonts/RobotoSlab/roboto-slab-v7-bold.eot");src:url("../fonts/RobotoSlab/roboto-slab-v7-bold.eot?#iefix") format("embedded-opentype"),url("../fonts/RobotoSlab/roboto-slab-v7-bold.woff2") format("woff2"),url("../fonts/RobotoSlab/roboto-slab-v7-bold.woff") format("woff"),url("../fonts/RobotoSlab/roboto-slab-v7-bold.ttf") format("truetype")} diff --git a/docs/build/html/_static/custom.css b/docs/build/html/_static/custom.css new file mode 100644 index 0000000..2a924f1 --- /dev/null +++ b/docs/build/html/_static/custom.css @@ -0,0 +1 @@ +/* This file intentionally left blank. */ diff --git a/docs/build/html/_static/doctools.js b/docs/build/html/_static/doctools.js index ffadbec..344db17 100644 --- a/docs/build/html/_static/doctools.js +++ b/docs/build/html/_static/doctools.js @@ -4,7 +4,7 @@ * * Sphinx JavaScript utilities for all documentation. * - * :copyright: Copyright 2007-2018 by the Sphinx team, see AUTHORS. + * :copyright: Copyright 2007-2019 by the Sphinx team, see AUTHORS. * :license: BSD, see LICENSE for details. 
* */ diff --git a/docs/build/html/_static/documentation_options.js b/docs/build/html/_static/documentation_options.js index 168d437..bee34fe 100644 --- a/docs/build/html/_static/documentation_options.js +++ b/docs/build/html/_static/documentation_options.js @@ -1,296 +1,10 @@ var DOCUMENTATION_OPTIONS = { URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), - VERSION: '', + VERSION: '0.1', LANGUAGE: 'None', COLLAPSE_INDEX: false, FILE_SUFFIX: '.html', HAS_SOURCE: true, SOURCELINK_SUFFIX: '.txt', NAVIGATION_WITH_KEYS: false, - SEARCH_LANGUAGE_STOP_WORDS: ["a","and","are","as","at","be","but","by","for","if","in","into","is","it","near","no","not","of","on","or","such","that","the","their","then","there","these","they","this","to","was","will","with"] -}; - - - -/* Non-minified version JS is _stemmer.js if file is provided */ -/** - * Porter Stemmer - */ -var Stemmer = function() { - - var step2list = { - ational: 'ate', - tional: 'tion', - enci: 'ence', - anci: 'ance', - izer: 'ize', - bli: 'ble', - alli: 'al', - entli: 'ent', - eli: 'e', - ousli: 'ous', - ization: 'ize', - ation: 'ate', - ator: 'ate', - alism: 'al', - iveness: 'ive', - fulness: 'ful', - ousness: 'ous', - aliti: 'al', - iviti: 'ive', - biliti: 'ble', - logi: 'log' - }; - - var step3list = { - icate: 'ic', - ative: '', - alize: 'al', - iciti: 'ic', - ical: 'ic', - ful: '', - ness: '' - }; - - var c = "[^aeiou]"; // consonant - var v = "[aeiouy]"; // vowel - var C = c + "[^aeiouy]*"; // consonant sequence - var V = v + "[aeiou]*"; // vowel sequence - - var mgr0 = "^(" + C + ")?" + V + C; // [C]VC... is m>0 - var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1 - var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1 - var s_v = "^(" + C + ")?" 
+ v; // vowel in stem - - this.stemWord = function (w) { - var stem; - var suffix; - var firstch; - var origword = w; - - if (w.length < 3) - return w; - - var re; - var re2; - var re3; - var re4; - - firstch = w.substr(0,1); - if (firstch == "y") - w = firstch.toUpperCase() + w.substr(1); - - // Step 1a - re = /^(.+?)(ss|i)es$/; - re2 = /^(.+?)([^s])s$/; - - if (re.test(w)) - w = w.replace(re,"$1$2"); - else if (re2.test(w)) - w = w.replace(re2,"$1$2"); - - // Step 1b - re = /^(.+?)eed$/; - re2 = /^(.+?)(ed|ing)$/; - if (re.test(w)) { - var fp = re.exec(w); - re = new RegExp(mgr0); - if (re.test(fp[1])) { - re = /.$/; - w = w.replace(re,""); - } - } - else if (re2.test(w)) { - var fp = re2.exec(w); - stem = fp[1]; - re2 = new RegExp(s_v); - if (re2.test(stem)) { - w = stem; - re2 = /(at|bl|iz)$/; - re3 = new RegExp("([^aeiouylsz])\\1$"); - re4 = new RegExp("^" + C + v + "[^aeiouwxy]$"); - if (re2.test(w)) - w = w + "e"; - else if (re3.test(w)) { - re = /.$/; - w = w.replace(re,""); - } - else if (re4.test(w)) - w = w + "e"; - } - } - - // Step 1c - re = /^(.+?)y$/; - if (re.test(w)) { - var fp = re.exec(w); - stem = fp[1]; - re = new RegExp(s_v); - if (re.test(stem)) - w = stem + "i"; - } - - // Step 2 - re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/; - if (re.test(w)) { - var fp = re.exec(w); - stem = fp[1]; - suffix = fp[2]; - re = new RegExp(mgr0); - if (re.test(stem)) - w = stem + step2list[suffix]; - } - - // Step 3 - re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/; - if (re.test(w)) { - var fp = re.exec(w); - stem = fp[1]; - suffix = fp[2]; - re = new RegExp(mgr0); - if (re.test(stem)) - w = stem + step3list[suffix]; - } - - // Step 4 - re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/; - re2 = /^(.+?)(s|t)(ion)$/; - if (re.test(w)) { - var fp = re.exec(w); - stem = fp[1]; - re = new RegExp(mgr1); - if (re.test(stem)) - w = stem; - } - else if (re2.test(w)) { - var fp = re2.exec(w); - stem = fp[1] + fp[2]; - re2 = new RegExp(mgr1); - if (re2.test(stem)) - w = stem; - } - - // Step 5 - re = /^(.+?)e$/; - if (re.test(w)) { - var fp = re.exec(w); - stem = fp[1]; - re = new RegExp(mgr1); - re2 = new RegExp(meq1); - re3 = new RegExp("^" + C + v + "[^aeiouwxy]$"); - if (re.test(stem) || (re2.test(stem) && !(re3.test(stem)))) - w = stem; - } - re = /ll$/; - re2 = new RegExp(mgr1); - if (re.test(w) && re2.test(w)) { - re = /.$/; - w = w.replace(re,""); - } - - // and turn initial Y back to y - if (firstch == "y") - w = firstch.toLowerCase() + w.substr(1); - return w; - } -} - - - - - -var splitChars = (function() { - var result = {}; - var singles = [96, 180, 187, 191, 215, 247, 749, 885, 903, 907, 909, 930, 1014, 1648, - 1748, 1809, 2416, 2473, 2481, 2526, 2601, 2609, 2612, 2615, 2653, 2702, - 2706, 2729, 2737, 2740, 2857, 2865, 2868, 2910, 2928, 2948, 2961, 2971, - 2973, 3085, 3089, 3113, 3124, 3213, 3217, 3241, 3252, 3295, 3341, 3345, - 3369, 3506, 3516, 3633, 3715, 3721, 3736, 3744, 3748, 3750, 3756, 3761, - 3781, 3912, 4239, 4347, 4681, 4695, 4697, 4745, 4785, 4799, 4801, 4823, - 4881, 5760, 5901, 5997, 6313, 7405, 8024, 8026, 8028, 8030, 8117, 8125, - 8133, 8181, 8468, 8485, 8487, 8489, 8494, 8527, 11311, 11359, 11687, 11695, - 11703, 11711, 11719, 11727, 11735, 12448, 12539, 43010, 43014, 43019, 43587, - 43696, 43713, 64286, 64297, 64311, 64317, 64319, 64322, 64325, 65141]; - var i, j, start, end; - for (i = 0; i < singles.length; i++) { - 
result[singles[i]] = true; - } - var ranges = [[0, 47], [58, 64], [91, 94], [123, 169], [171, 177], [182, 184], [706, 709], - [722, 735], [741, 747], [751, 879], [888, 889], [894, 901], [1154, 1161], - [1318, 1328], [1367, 1368], [1370, 1376], [1416, 1487], [1515, 1519], [1523, 1568], - [1611, 1631], [1642, 1645], [1750, 1764], [1767, 1773], [1789, 1790], [1792, 1807], - [1840, 1868], [1958, 1968], [1970, 1983], [2027, 2035], [2038, 2041], [2043, 2047], - [2070, 2073], [2075, 2083], [2085, 2087], [2089, 2307], [2362, 2364], [2366, 2383], - [2385, 2391], [2402, 2405], [2419, 2424], [2432, 2436], [2445, 2446], [2449, 2450], - [2483, 2485], [2490, 2492], [2494, 2509], [2511, 2523], [2530, 2533], [2546, 2547], - [2554, 2564], [2571, 2574], [2577, 2578], [2618, 2648], [2655, 2661], [2672, 2673], - [2677, 2692], [2746, 2748], [2750, 2767], [2769, 2783], [2786, 2789], [2800, 2820], - [2829, 2830], [2833, 2834], [2874, 2876], [2878, 2907], [2914, 2917], [2930, 2946], - [2955, 2957], [2966, 2968], [2976, 2978], [2981, 2983], [2987, 2989], [3002, 3023], - [3025, 3045], [3059, 3076], [3130, 3132], [3134, 3159], [3162, 3167], [3170, 3173], - [3184, 3191], [3199, 3204], [3258, 3260], [3262, 3293], [3298, 3301], [3312, 3332], - [3386, 3388], [3390, 3423], [3426, 3429], [3446, 3449], [3456, 3460], [3479, 3481], - [3518, 3519], [3527, 3584], [3636, 3647], [3655, 3663], [3674, 3712], [3717, 3718], - [3723, 3724], [3726, 3731], [3752, 3753], [3764, 3772], [3774, 3775], [3783, 3791], - [3802, 3803], [3806, 3839], [3841, 3871], [3892, 3903], [3949, 3975], [3980, 4095], - [4139, 4158], [4170, 4175], [4182, 4185], [4190, 4192], [4194, 4196], [4199, 4205], - [4209, 4212], [4226, 4237], [4250, 4255], [4294, 4303], [4349, 4351], [4686, 4687], - [4702, 4703], [4750, 4751], [4790, 4791], [4806, 4807], [4886, 4887], [4955, 4968], - [4989, 4991], [5008, 5023], [5109, 5120], [5741, 5742], [5787, 5791], [5867, 5869], - [5873, 5887], [5906, 5919], [5938, 5951], [5970, 5983], [6001, 6015], [6068, 6102], - [6104, 6107], [6109, 6111], [6122, 6127], [6138, 6159], [6170, 6175], [6264, 6271], - [6315, 6319], [6390, 6399], [6429, 6469], [6510, 6511], [6517, 6527], [6572, 6592], - [6600, 6607], [6619, 6655], [6679, 6687], [6741, 6783], [6794, 6799], [6810, 6822], - [6824, 6916], [6964, 6980], [6988, 6991], [7002, 7042], [7073, 7085], [7098, 7167], - [7204, 7231], [7242, 7244], [7294, 7400], [7410, 7423], [7616, 7679], [7958, 7959], - [7966, 7967], [8006, 8007], [8014, 8015], [8062, 8063], [8127, 8129], [8141, 8143], - [8148, 8149], [8156, 8159], [8173, 8177], [8189, 8303], [8306, 8307], [8314, 8318], - [8330, 8335], [8341, 8449], [8451, 8454], [8456, 8457], [8470, 8472], [8478, 8483], - [8506, 8507], [8512, 8516], [8522, 8525], [8586, 9311], [9372, 9449], [9472, 10101], - [10132, 11263], [11493, 11498], [11503, 11516], [11518, 11519], [11558, 11567], - [11622, 11630], [11632, 11647], [11671, 11679], [11743, 11822], [11824, 12292], - [12296, 12320], [12330, 12336], [12342, 12343], [12349, 12352], [12439, 12444], - [12544, 12548], [12590, 12592], [12687, 12689], [12694, 12703], [12728, 12783], - [12800, 12831], [12842, 12880], [12896, 12927], [12938, 12976], [12992, 13311], - [19894, 19967], [40908, 40959], [42125, 42191], [42238, 42239], [42509, 42511], - [42540, 42559], [42592, 42593], [42607, 42622], [42648, 42655], [42736, 42774], - [42784, 42785], [42889, 42890], [42893, 43002], [43043, 43055], [43062, 43071], - [43124, 43137], [43188, 43215], [43226, 43249], [43256, 43258], [43260, 43263], - [43302, 43311], [43335, 43359], 
[43389, 43395], [43443, 43470], [43482, 43519], - [43561, 43583], [43596, 43599], [43610, 43615], [43639, 43641], [43643, 43647], - [43698, 43700], [43703, 43704], [43710, 43711], [43715, 43738], [43742, 43967], - [44003, 44015], [44026, 44031], [55204, 55215], [55239, 55242], [55292, 55295], - [57344, 63743], [64046, 64047], [64110, 64111], [64218, 64255], [64263, 64274], - [64280, 64284], [64434, 64466], [64830, 64847], [64912, 64913], [64968, 65007], - [65020, 65135], [65277, 65295], [65306, 65312], [65339, 65344], [65371, 65381], - [65471, 65473], [65480, 65481], [65488, 65489], [65496, 65497]]; - for (i = 0; i < ranges.length; i++) { - start = ranges[i][0]; - end = ranges[i][1]; - for (j = start; j <= end; j++) { - result[j] = true; - } - } - return result; -})(); - -function splitQuery(query) { - var result = []; - var start = -1; - for (var i = 0; i < query.length; i++) { - if (splitChars[query.charCodeAt(i)]) { - if (start !== -1) { - result.push(query.slice(start, i)); - start = -1; - } - } else if (start === -1) { - start = i; - } - } - if (start !== -1) { - result.push(query.slice(start)); - } - return result; -} - - +}; \ No newline at end of file diff --git a/docs/build/html/_static/fonts/FontAwesome.otf b/docs/build/html/_static/fonts/FontAwesome.otf deleted file mode 100644 index 401ec0f..0000000 Binary files a/docs/build/html/_static/fonts/FontAwesome.otf and /dev/null differ diff --git a/docs/build/html/_static/fonts/Inconsolata-Bold.ttf b/docs/build/html/_static/fonts/Inconsolata-Bold.ttf new file mode 100644 index 0000000..809c1f5 Binary files /dev/null and b/docs/build/html/_static/fonts/Inconsolata-Bold.ttf differ diff --git a/docs/build/html/_static/fonts/Inconsolata-Regular.ttf b/docs/build/html/_static/fonts/Inconsolata-Regular.ttf new file mode 100644 index 0000000..fc981ce Binary files /dev/null and b/docs/build/html/_static/fonts/Inconsolata-Regular.ttf differ diff --git a/docs/build/html/_static/fonts/Inconsolata.ttf b/docs/build/html/_static/fonts/Inconsolata.ttf new file mode 100644 index 0000000..4b8a36d Binary files /dev/null and b/docs/build/html/_static/fonts/Inconsolata.ttf differ diff --git a/docs/build/html/_static/fonts/Lato-Bold.ttf b/docs/build/html/_static/fonts/Lato-Bold.ttf new file mode 100644 index 0000000..1d23c70 Binary files /dev/null and b/docs/build/html/_static/fonts/Lato-Bold.ttf differ diff --git a/docs/build/html/_static/fonts/Lato-Regular.ttf b/docs/build/html/_static/fonts/Lato-Regular.ttf new file mode 100644 index 0000000..0f3d0f8 Binary files /dev/null and b/docs/build/html/_static/fonts/Lato-Regular.ttf differ diff --git a/docs/build/html/_static/fonts/Lato/lato-bold.eot b/docs/build/html/_static/fonts/Lato/lato-bold.eot new file mode 100644 index 0000000..3361183 Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-bold.eot differ diff --git a/docs/build/html/_static/fonts/Lato/lato-bold.ttf b/docs/build/html/_static/fonts/Lato/lato-bold.ttf new file mode 100644 index 0000000..29f691d Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-bold.ttf differ diff --git a/docs/build/html/_static/fonts/Lato/lato-bold.woff b/docs/build/html/_static/fonts/Lato/lato-bold.woff new file mode 100644 index 0000000..c6dff51 Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-bold.woff differ diff --git a/docs/build/html/_static/fonts/Lato/lato-bold.woff2 b/docs/build/html/_static/fonts/Lato/lato-bold.woff2 new file mode 100644 index 0000000..bb19504 Binary files /dev/null and 
b/docs/build/html/_static/fonts/Lato/lato-bold.woff2 differ diff --git a/docs/build/html/_static/fonts/Lato/lato-bolditalic.eot b/docs/build/html/_static/fonts/Lato/lato-bolditalic.eot new file mode 100644 index 0000000..3d41549 Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-bolditalic.eot differ diff --git a/docs/build/html/_static/fonts/Lato/lato-bolditalic.ttf b/docs/build/html/_static/fonts/Lato/lato-bolditalic.ttf new file mode 100644 index 0000000..f402040 Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-bolditalic.ttf differ diff --git a/docs/build/html/_static/fonts/Lato/lato-bolditalic.woff b/docs/build/html/_static/fonts/Lato/lato-bolditalic.woff new file mode 100644 index 0000000..88ad05b Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-bolditalic.woff differ diff --git a/docs/build/html/_static/fonts/Lato/lato-bolditalic.woff2 b/docs/build/html/_static/fonts/Lato/lato-bolditalic.woff2 new file mode 100644 index 0000000..c4e3d80 Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-bolditalic.woff2 differ diff --git a/docs/build/html/_static/fonts/Lato/lato-italic.eot b/docs/build/html/_static/fonts/Lato/lato-italic.eot new file mode 100644 index 0000000..3f82642 Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-italic.eot differ diff --git a/docs/build/html/_static/fonts/Lato/lato-italic.ttf b/docs/build/html/_static/fonts/Lato/lato-italic.ttf new file mode 100644 index 0000000..b4bfc9b Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-italic.ttf differ diff --git a/docs/build/html/_static/fonts/Lato/lato-italic.woff b/docs/build/html/_static/fonts/Lato/lato-italic.woff new file mode 100644 index 0000000..76114bc Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-italic.woff differ diff --git a/docs/build/html/_static/fonts/Lato/lato-italic.woff2 b/docs/build/html/_static/fonts/Lato/lato-italic.woff2 new file mode 100644 index 0000000..3404f37 Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-italic.woff2 differ diff --git a/docs/build/html/_static/fonts/Lato/lato-regular.eot b/docs/build/html/_static/fonts/Lato/lato-regular.eot new file mode 100644 index 0000000..11e3f2a Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-regular.eot differ diff --git a/docs/build/html/_static/fonts/Lato/lato-regular.ttf b/docs/build/html/_static/fonts/Lato/lato-regular.ttf new file mode 100644 index 0000000..74decd9 Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-regular.ttf differ diff --git a/docs/build/html/_static/fonts/Lato/lato-regular.woff b/docs/build/html/_static/fonts/Lato/lato-regular.woff new file mode 100644 index 0000000..ae1307f Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-regular.woff differ diff --git a/docs/build/html/_static/fonts/Lato/lato-regular.woff2 b/docs/build/html/_static/fonts/Lato/lato-regular.woff2 new file mode 100644 index 0000000..3bf9843 Binary files /dev/null and b/docs/build/html/_static/fonts/Lato/lato-regular.woff2 differ diff --git a/docs/build/html/_static/fonts/RobotoSlab-Bold.ttf b/docs/build/html/_static/fonts/RobotoSlab-Bold.ttf new file mode 100644 index 0000000..df5d1df Binary files /dev/null and b/docs/build/html/_static/fonts/RobotoSlab-Bold.ttf differ diff --git a/docs/build/html/_static/fonts/RobotoSlab-Regular.ttf b/docs/build/html/_static/fonts/RobotoSlab-Regular.ttf new file mode 100644 index 0000000..eb52a79 Binary files 
/dev/null and b/docs/build/html/_static/fonts/RobotoSlab-Regular.ttf differ diff --git a/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot new file mode 100644 index 0000000..79dc8ef Binary files /dev/null and b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot differ diff --git a/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf new file mode 100644 index 0000000..df5d1df Binary files /dev/null and b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf differ diff --git a/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff new file mode 100644 index 0000000..6cb6000 Binary files /dev/null and b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff differ diff --git a/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 new file mode 100644 index 0000000..7059e23 Binary files /dev/null and b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 differ diff --git a/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot new file mode 100644 index 0000000..2f7ca78 Binary files /dev/null and b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot differ diff --git a/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf new file mode 100644 index 0000000..eb52a79 Binary files /dev/null and b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf differ diff --git a/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff new file mode 100644 index 0000000..f815f63 Binary files /dev/null and b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff differ diff --git a/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 new file mode 100644 index 0000000..f2c76e5 Binary files /dev/null and b/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 differ diff --git a/docs/build/html/_static/js/theme.js b/docs/build/html/_static/js/theme.js index 96672c6..8555d79 100644 --- a/docs/build/html/_static/js/theme.js +++ b/docs/build/html/_static/js/theme.js @@ -1,3 +1,3 @@ -/* sphinx_rtd_theme version 0.4.2 | MIT license */ -/* Built 20181005 13:10 */ -require=function r(s,a,l){function c(e,n){if(!a[e]){if(!s[e]){var i="function"==typeof require&&require;if(!n&&i)return i(e,!0);if(u)return u(e,!0);var t=new Error("Cannot find module '"+e+"'");throw t.code="MODULE_NOT_FOUND",t}var o=a[e]={exports:{}};s[e][0].call(o.exports,function(n){return c(s[e][1][n]||n)},o,o.exports,r,s,a,l)}return a[e].exports}for(var u="function"==typeof require&&require,n=0;n"),i("table.docutils.footnote").wrap("
"),i("table.docutils.citation").wrap("
"),i(".wy-menu-vertical ul").not(".simple").siblings("a").each(function(){var e=i(this);expand=i(''),expand.on("click",function(n){return t.toggleCurrent(e),n.stopPropagation(),!1}),e.prepend(expand)})},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),i=e.find('[href="'+n+'"]');if(0===i.length){var t=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(i=e.find('[href="#'+t.attr("id")+'"]')).length&&(i=e.find('[href="#"]'))}0this.docHeight||(this.navBar.scrollTop(i),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",function(){this.linkScroll=!1})},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current"),e.siblings().find("li.current").removeClass("current"),e.find("> ul li.current").removeClass("current"),e.toggleClass("current")}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:e.exports.ThemeNav,StickyNav:e.exports.ThemeNav}),function(){for(var r=0,n=["ms","moz","webkit","o"],e=0;e"),i("table.docutils.footnote").wrap("
"),i("table.docutils.citation").wrap("
"),i(".wy-menu-vertical ul").not(".simple").siblings("a").each(function(){var e=i(this);expand=i(''),expand.on("click",function(n){return t.toggleCurrent(e),n.stopPropagation(),!1}),e.prepend(expand)})},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),i=e.find('[href="'+n+'"]');if(0===i.length){var t=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(i=e.find('[href="#'+t.attr("id")+'"]')).length&&(i=e.find('[href="#"]'))}0this.docHeight||(this.navBar.scrollTop(i),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",function(){this.linkScroll=!1})},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current"),e.siblings().find("li.current").removeClass("current"),e.find("> ul li.current").removeClass("current"),e.toggleClass("current")}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:e.exports.ThemeNav,StickyNav:e.exports.ThemeNav}),function(){for(var r=0,n=["ms","moz","webkit","o"],e=0;e0 + var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1 + var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1 + var s_v = "^(" + C + ")?" + v; // vowel in stem + + this.stemWord = function (w) { + var stem; + var suffix; + var firstch; + var origword = w; + + if (w.length < 3) + return w; + + var re; + var re2; + var re3; + var re4; + + firstch = w.substr(0,1); + if (firstch == "y") + w = firstch.toUpperCase() + w.substr(1); + + // Step 1a + re = /^(.+?)(ss|i)es$/; + re2 = /^(.+?)([^s])s$/; + + if (re.test(w)) + w = w.replace(re,"$1$2"); + else if (re2.test(w)) + w = w.replace(re2,"$1$2"); + + // Step 1b + re = /^(.+?)eed$/; + re2 = /^(.+?)(ed|ing)$/; + if (re.test(w)) { + var fp = re.exec(w); + re = new RegExp(mgr0); + if (re.test(fp[1])) { + re = /.$/; + w = w.replace(re,""); + } + } + else if (re2.test(w)) { + var fp = re2.exec(w); + stem = fp[1]; + re2 = new RegExp(s_v); + if (re2.test(stem)) { + w = stem; + re2 = /(at|bl|iz)$/; + re3 = new RegExp("([^aeiouylsz])\\1$"); + re4 = new RegExp("^" + C + v + "[^aeiouwxy]$"); + if (re2.test(w)) + w = w + "e"; + else if (re3.test(w)) { + re = /.$/; + w = w.replace(re,""); + } + else if (re4.test(w)) + w = w + "e"; + } + } + + // Step 1c + re = /^(.+?)y$/; + if (re.test(w)) { + var fp = re.exec(w); + stem = fp[1]; + re = new RegExp(s_v); + if (re.test(stem)) + w = stem + "i"; + } + + // Step 2 + re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/; + if (re.test(w)) { + var fp = re.exec(w); + stem = fp[1]; + suffix = fp[2]; + re = new RegExp(mgr0); + if (re.test(stem)) + w = stem + step2list[suffix]; + } + + // Step 3 + re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/; + if (re.test(w)) { + var fp = re.exec(w); + stem = fp[1]; + suffix = fp[2]; + re = new RegExp(mgr0); + if (re.test(stem)) + w = stem + step3list[suffix]; + } + + // Step 4 + re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/; + re2 = /^(.+?)(s|t)(ion)$/; + if (re.test(w)) { + var fp = re.exec(w); + stem = fp[1]; + re = new RegExp(mgr1); + if (re.test(stem)) + w = stem; + } + else if (re2.test(w)) { + var fp = re2.exec(w); + stem = fp[1] + fp[2]; + re2 = new RegExp(mgr1); + if (re2.test(stem)) + w = stem; + } + + // Step 5 + re = /^(.+?)e$/; + if (re.test(w)) { + var fp = 
re.exec(w); + stem = fp[1]; + re = new RegExp(mgr1); + re2 = new RegExp(meq1); + re3 = new RegExp("^" + C + v + "[^aeiouwxy]$"); + if (re.test(stem) || (re2.test(stem) && !(re3.test(stem)))) + w = stem; + } + re = /ll$/; + re2 = new RegExp(mgr1); + if (re.test(w) && re2.test(w)) { + re = /.$/; + w = w.replace(re,""); + } + + // and turn initial Y back to y + if (firstch == "y") + w = firstch.toLowerCase() + w.substr(1); + return w; + } +} + + + + + +var splitChars = (function() { + var result = {}; + var singles = [96, 180, 187, 191, 215, 247, 749, 885, 903, 907, 909, 930, 1014, 1648, + 1748, 1809, 2416, 2473, 2481, 2526, 2601, 2609, 2612, 2615, 2653, 2702, + 2706, 2729, 2737, 2740, 2857, 2865, 2868, 2910, 2928, 2948, 2961, 2971, + 2973, 3085, 3089, 3113, 3124, 3213, 3217, 3241, 3252, 3295, 3341, 3345, + 3369, 3506, 3516, 3633, 3715, 3721, 3736, 3744, 3748, 3750, 3756, 3761, + 3781, 3912, 4239, 4347, 4681, 4695, 4697, 4745, 4785, 4799, 4801, 4823, + 4881, 5760, 5901, 5997, 6313, 7405, 8024, 8026, 8028, 8030, 8117, 8125, + 8133, 8181, 8468, 8485, 8487, 8489, 8494, 8527, 11311, 11359, 11687, 11695, + 11703, 11711, 11719, 11727, 11735, 12448, 12539, 43010, 43014, 43019, 43587, + 43696, 43713, 64286, 64297, 64311, 64317, 64319, 64322, 64325, 65141]; + var i, j, start, end; + for (i = 0; i < singles.length; i++) { + result[singles[i]] = true; + } + var ranges = [[0, 47], [58, 64], [91, 94], [123, 169], [171, 177], [182, 184], [706, 709], + [722, 735], [741, 747], [751, 879], [888, 889], [894, 901], [1154, 1161], + [1318, 1328], [1367, 1368], [1370, 1376], [1416, 1487], [1515, 1519], [1523, 1568], + [1611, 1631], [1642, 1645], [1750, 1764], [1767, 1773], [1789, 1790], [1792, 1807], + [1840, 1868], [1958, 1968], [1970, 1983], [2027, 2035], [2038, 2041], [2043, 2047], + [2070, 2073], [2075, 2083], [2085, 2087], [2089, 2307], [2362, 2364], [2366, 2383], + [2385, 2391], [2402, 2405], [2419, 2424], [2432, 2436], [2445, 2446], [2449, 2450], + [2483, 2485], [2490, 2492], [2494, 2509], [2511, 2523], [2530, 2533], [2546, 2547], + [2554, 2564], [2571, 2574], [2577, 2578], [2618, 2648], [2655, 2661], [2672, 2673], + [2677, 2692], [2746, 2748], [2750, 2767], [2769, 2783], [2786, 2789], [2800, 2820], + [2829, 2830], [2833, 2834], [2874, 2876], [2878, 2907], [2914, 2917], [2930, 2946], + [2955, 2957], [2966, 2968], [2976, 2978], [2981, 2983], [2987, 2989], [3002, 3023], + [3025, 3045], [3059, 3076], [3130, 3132], [3134, 3159], [3162, 3167], [3170, 3173], + [3184, 3191], [3199, 3204], [3258, 3260], [3262, 3293], [3298, 3301], [3312, 3332], + [3386, 3388], [3390, 3423], [3426, 3429], [3446, 3449], [3456, 3460], [3479, 3481], + [3518, 3519], [3527, 3584], [3636, 3647], [3655, 3663], [3674, 3712], [3717, 3718], + [3723, 3724], [3726, 3731], [3752, 3753], [3764, 3772], [3774, 3775], [3783, 3791], + [3802, 3803], [3806, 3839], [3841, 3871], [3892, 3903], [3949, 3975], [3980, 4095], + [4139, 4158], [4170, 4175], [4182, 4185], [4190, 4192], [4194, 4196], [4199, 4205], + [4209, 4212], [4226, 4237], [4250, 4255], [4294, 4303], [4349, 4351], [4686, 4687], + [4702, 4703], [4750, 4751], [4790, 4791], [4806, 4807], [4886, 4887], [4955, 4968], + [4989, 4991], [5008, 5023], [5109, 5120], [5741, 5742], [5787, 5791], [5867, 5869], + [5873, 5887], [5906, 5919], [5938, 5951], [5970, 5983], [6001, 6015], [6068, 6102], + [6104, 6107], [6109, 6111], [6122, 6127], [6138, 6159], [6170, 6175], [6264, 6271], + [6315, 6319], [6390, 6399], [6429, 6469], [6510, 6511], [6517, 6527], [6572, 6592], + [6600, 6607], [6619, 6655], [6679, 6687], 
[6741, 6783], [6794, 6799], [6810, 6822], + [6824, 6916], [6964, 6980], [6988, 6991], [7002, 7042], [7073, 7085], [7098, 7167], + [7204, 7231], [7242, 7244], [7294, 7400], [7410, 7423], [7616, 7679], [7958, 7959], + [7966, 7967], [8006, 8007], [8014, 8015], [8062, 8063], [8127, 8129], [8141, 8143], + [8148, 8149], [8156, 8159], [8173, 8177], [8189, 8303], [8306, 8307], [8314, 8318], + [8330, 8335], [8341, 8449], [8451, 8454], [8456, 8457], [8470, 8472], [8478, 8483], + [8506, 8507], [8512, 8516], [8522, 8525], [8586, 9311], [9372, 9449], [9472, 10101], + [10132, 11263], [11493, 11498], [11503, 11516], [11518, 11519], [11558, 11567], + [11622, 11630], [11632, 11647], [11671, 11679], [11743, 11822], [11824, 12292], + [12296, 12320], [12330, 12336], [12342, 12343], [12349, 12352], [12439, 12444], + [12544, 12548], [12590, 12592], [12687, 12689], [12694, 12703], [12728, 12783], + [12800, 12831], [12842, 12880], [12896, 12927], [12938, 12976], [12992, 13311], + [19894, 19967], [40908, 40959], [42125, 42191], [42238, 42239], [42509, 42511], + [42540, 42559], [42592, 42593], [42607, 42622], [42648, 42655], [42736, 42774], + [42784, 42785], [42889, 42890], [42893, 43002], [43043, 43055], [43062, 43071], + [43124, 43137], [43188, 43215], [43226, 43249], [43256, 43258], [43260, 43263], + [43302, 43311], [43335, 43359], [43389, 43395], [43443, 43470], [43482, 43519], + [43561, 43583], [43596, 43599], [43610, 43615], [43639, 43641], [43643, 43647], + [43698, 43700], [43703, 43704], [43710, 43711], [43715, 43738], [43742, 43967], + [44003, 44015], [44026, 44031], [55204, 55215], [55239, 55242], [55292, 55295], + [57344, 63743], [64046, 64047], [64110, 64111], [64218, 64255], [64263, 64274], + [64280, 64284], [64434, 64466], [64830, 64847], [64912, 64913], [64968, 65007], + [65020, 65135], [65277, 65295], [65306, 65312], [65339, 65344], [65371, 65381], + [65471, 65473], [65480, 65481], [65488, 65489], [65496, 65497]]; + for (i = 0; i < ranges.length; i++) { + start = ranges[i][0]; + end = ranges[i][1]; + for (j = start; j <= end; j++) { + result[j] = true; + } + } + return result; +})(); + +function splitQuery(query) { + var result = []; + var start = -1; + for (var i = 0; i < query.length; i++) { + if (splitChars[query.charCodeAt(i)]) { + if (start !== -1) { + result.push(query.slice(start, i)); + start = -1; + } + } else if (start === -1) { + start = i; + } + } + if (start !== -1) { + result.push(query.slice(start)); + } + return result; +} + + diff --git a/docs/build/html/_static/searchtools.js b/docs/build/html/_static/searchtools.js index 7473859..5ff3180 100644 --- a/docs/build/html/_static/searchtools.js +++ b/docs/build/html/_static/searchtools.js @@ -4,7 +4,7 @@ * * Sphinx JavaScript utilities for the full-text search. * - * :copyright: Copyright 2007-2018 by the Sphinx team, see AUTHORS. + * :copyright: Copyright 2007-2019 by the Sphinx team, see AUTHORS. * :license: BSD, see LICENSE for details. * */ @@ -138,7 +138,6 @@ var Search = { */ query : function(query) { var i; - var stopwords = DOCUMENTATION_OPTIONS.SEARCH_LANGUAGE_STOP_WORDS; // stem the searchterms and add them to the correct list var stemmer = new Stemmer(); diff --git a/docs/build/html/_static/websupport.js b/docs/build/html/_static/websupport.js index 78e14bb..3b4999e 100644 --- a/docs/build/html/_static/websupport.js +++ b/docs/build/html/_static/websupport.js @@ -4,7 +4,7 @@ * * sphinx.websupport utilities for all documentation. * - * :copyright: Copyright 2007-2018 by the Sphinx team, see AUTHORS. 
+ * :copyright: Copyright 2007-2019 by the Sphinx team, see AUTHORS. * :license: BSD, see LICENSE for details. * */ diff --git a/docs/build/html/convlab.agent.algorithm.html b/docs/build/html/convlab.agent.algorithm.html new file mode 100644 index 0000000..c164f22 --- /dev/null +++ b/docs/build/html/convlab.agent.algorithm.html @@ -0,0 +1,1488 @@

convlab.agent.algorithm package — ConvLab 0.1 documentation
convlab.agent.algorithm package


Submodules


convlab.agent.algorithm.actor_critic module

class convlab.agent.algorithm.actor_critic.ActorCritic(agent, global_nets=None)[source]

Bases: convlab.agent.algorithm.reinforce.Reinforce


Implementation of single-threaded Advantage Actor-Critic.
Original paper: "Asynchronous Methods for Deep Reinforcement Learning", https://arxiv.org/abs/1602.01783
Algorithm-specific spec params:
memory.name: batch (through the OnPolicyBatchReplay memory class) or episodic (through the OnPolicyReplay memory class)
lam: if not null, used as the lambda value of generalized advantage estimation (GAE), introduced in "High-Dimensional Continuous Control Using Generalized Advantage Estimation", https://arxiv.org/abs/1506.02438. This lambda controls the bias-variance tradeoff for GAE: a floating-point value between 0 and 1, where lower values correspond to more bias and less variance, and higher values to more variance and less bias. The algorithm becomes A2C(GAE).
num_step_returns: if lam is null and this is not null, specifies the number of steps for N-step returns from "Asynchronous Methods for Deep Reinforcement Learning". The algorithm becomes A2C(Nstep).
If both lam and num_step_returns are null, the default TD error is used and the algorithm stays plain AC.
net.type: whether the actor and critic should share params (e.g. through 'MLPNetShared') or have separate params (e.g. through 'MLPNetSeparate'). If param sharing is used, the weights given to the policy and value components of the loss function can be controlled through 'policy_loss_coef' and 'val_loss_coef'.
Algorithm - separate actor and critic:

+
+
+
Repeat:
+
  1. Collect k examples
  2. Train the critic network using these examples
  3. Calculate the advantage of each example using the critic
  4. Multiply the advantage by the negative of the log probability of the action taken, and sum all the values. This is the policy loss.
  5. Calculate the gradient of the parameters of the actor network with respect to the policy loss
  6. Update the actor network parameters using the gradient
+
+
+
+
+
Algorithm - shared parameters:
+
+
Repeat:
+
  1. Collect k examples
  2. Calculate the target for each example for the critic
  3. Compute the current estimate of the state-value for each example using the critic
  4. Calculate the critic loss using a regression loss (e.g. square loss) between the target and estimate of the state-value for each example
  5. Calculate the advantage of each example using the rewards and critic
  6. Multiply the advantage by the negative of the log probability of the action taken, and sum all the values. This is the policy loss.
  7. Compute the total loss by summing the value and policy losses
  8. Calculate the gradient of the parameters of the shared network with respect to the total loss
  9. Update the shared network parameters using the gradient
+
+
+
+
+

e.g. algorithm_spec +“algorithm”: {

+
+

“name”: “ActorCritic”, +“action_pdtype”: “default”, +“action_policy”: “default”, +“explore_var_spec”: null, +“gamma”: 0.99, +“lam”: 1.0, +“num_step_returns”: 100, +“entropy_coef_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 0.01, +“end_val”: 0.001, +“start_step”: 100, +“end_step”: 5000,
+

}, +“policy_loss_coef”: 1.0, +“val_loss_coef”: 0.01, +“training_frequency”: 1,

+
+

}

+

e.g. special net_spec param “shared” to share/separate Actor/Critic +“net”: {

+
+
“type”: “MLPNet”, +“shared”: true, +…
+
+
+calc_gae_advs_v_targets(batch, v_preds)[source]
+

Calculate GAE, and advs = GAE, v_targets = advs + v_preds +See GAE from Schulman et al. https://arxiv.org/pdf/1506.02438.pdf

+
+ +
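To make the computation concrete, here is a minimal plain-Python sketch of GAE under the definitions above (hypothetical helper names; the library's actual implementation is vectorized in torch and also masks episode boundaries, which this sketch ignores):

    def calc_gaes_sketch(rewards, v_preds, next_v_pred, gamma=0.99, lam=1.0):
        # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t); GAE_t = delta_t + gamma * lam * GAE_{t+1}
        v_preds = list(v_preds) + [next_v_pred]
        gaes = [0.0] * len(rewards)
        future_gae = 0.0
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + gamma * v_preds[t + 1] - v_preds[t]
            future_gae = delta + gamma * lam * future_gae
            gaes[t] = future_gae
        # advs = gaes; v_targets = [g + v for g, v in zip(gaes, v_preds[:-1])]
        return gaes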
+
+calc_nstep_advs_v_targets(batch, v_preds)[source]
+

Calculate N-step returns, and advs = nstep_rets - v_preds, v_targets = nstep_rets +See n-step advantage under http://rail.eecs.berkeley.edu/deeprlcourse-fa17/f17docs/lecture_5_actor_critic_pdf.pdf

+
+ +
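A rough sketch of the N-step return under the same conventions (hypothetical helper; ignores episode boundaries):

    def calc_nstep_return_sketch(rewards, next_v_pred, gamma=0.99):
        # nstep_ret = r_0 + gamma * r_1 + ... + gamma^n * V(s_n)
        ret = next_v_pred
        for r in reversed(rewards):  # rewards over the n steps
            ret = r + gamma * ret
        return ret  # adv = ret - v_pred; v_target = ret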
+
+calc_pdparam(x, net=None)[source]
+

The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.

+
+ +
+
+calc_pdparam_v(batch)[source]
+

Efficiently forward to get pdparam and v by batch for loss computation

+
+ +
+
+calc_policy_loss(batch, pdparams, advs)[source]
+

Calculate the actor’s policy loss

+
+ +
+
+calc_ret_advs_v_targets(batch, v_preds)[source]
+

Calculate plain returns, and advs = rets - v_preds, v_targets = rets

+
+ +
+
+calc_v(x, net=None, use_cache=True)[source]
+

Forward-pass to calculate the predicted state-value from critic_net.

+
+ +
+
+calc_val_loss(v_preds, v_targets)[source]
+

Calculate the critic’s value loss

+
+ +
+
+init_algorithm_params()[source]
+

Initialize other algorithm parameters

+
+ +
+
+init_nets(global_nets=None)[source]
+

Initialize the neural networks used to learn the actor and critic from the spec.
Below we automatically select an appropriate net based on the following conditions:
1. Whether the action space is discrete or continuous
  • Networks for continuous action spaces have two heads and return two values: the first is a tensor containing the mean of the action policy, the second is a tensor containing the std deviation of the action policy. The distribution is assumed to be a Gaussian (Normal) distribution.
  • Networks for discrete action spaces have a single head and return the logits for a categorical probability distribution over the discrete actions.
2. Whether the actor and critic are separate or share weights
  • If the networks share weights then the single network returns a list.
  • Continuous action spaces: the returned list contains 3 elements: the mean output of the actor (policy), the std dev of the policy, and the state-value estimated by the network.
  • Discrete action spaces: the returned list contains 2 elements: a tensor containing the logits for a categorical probability distribution over the actions, and the state-value estimated by the network.
3. Whether the network type is feedforward, convolutional, or recurrent
  • Feedforward and convolutional networks take a single state as input and require an OnPolicyReplay or OnPolicyBatchReplay memory.
  • Recurrent networks take n states as input and require env spec “frame_op”: “concat”, “frame_op_len”: seq_len.
+
+ +
+
+train()[source]
+

Train actor critic by computing the loss in batch efficiently

+
+ +
+
+update()[source]
+

Implement algorithm update, or throw NotImplementedError

+
+ +
+ +
+
+

convlab.agent.algorithm.base module

+
+
+class convlab.agent.algorithm.base.Algorithm(agent, global_nets=None)[source]
+

Bases: abc.ABC

+

Abstract class ancestor to all Algorithms, +specifies the necessary design blueprint for agent to work in Lab. +Mostly, implement just the abstract methods and properties.

+
+
+act(state)[source]
+

Standard act method.

+
+ +
+
+calc_pdparam(x, evaluate=True, net=None)[source]
+

To get the pdparam for action policy sampling, do a forward pass of the appropriate net, and pick the correct outputs. +The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.

+
+ +
+
+init_algorithm_params()[source]
+

Initialize other algorithm parameters

+
+ +
+
+init_nets(global_nets=None)[source]
+

Initialize the neural network from the spec

+
+ +
+
+load()[source]
+

Load net models for algorithm given the required property self.net_names

+
+ +
+
+nanflat_to_data_a(data_name, nanflat_data_a)[source]
+

Reshape nanflat_data_a, e.g. action_a, from a single pass back into the API-conforming data_a

+
+ +
+
+post_init_nets()[source]
+

Method to conditionally load models. +Call at the end of init_nets() after setting self.net_names

+
+ +
+
+sample()[source]
+

Samples a batch from memory

+
+ +
+
+save(ckpt=None)[source]
+

Save net models for algorithm given the required property self.net_names

+
+ +
+
+space_act(state_a)[source]
+

Interface-level agent act method for all its bodies. Resolves state to state; get action and compose into action.

+
+ +
+
+space_sample()[source]
+

Samples a batch from memory

+
+ +
+
+space_train()[source]
+
+ +
+
+space_update()[source]
+
+ +
+
+train()[source]
+

Implement algorithm train, or throw NotImplementedError

+
+ +
+
+update()[source]
+

Implement algorithm update, or throw NotImplementedError

+
+ +
+ +
+
+

convlab.agent.algorithm.dqn module

+
+
+class convlab.agent.algorithm.dqn.DQN(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.dqn.DQNBase

+

DQN class

+

e.g. algorithm_spec +“algorithm”: {

+
+

“name”: “DQN”, +“action_pdtype”: “Argmax”, +“action_policy”: “epsilon_greedy”, +“explore_var_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 1.0, +“end_val”: 0.1, +“start_step”: 10, +“end_step”: 1000,
+

}, +“gamma”: 0.99, +“training_batch_iter”: 8, +“training_iter”: 4, +“training_frequency”: 10, +“training_start_step”: 10

+
+

}

+
+
+init_nets(global_nets=None)[source]
+

Initialize networks

+
+ +
+ +
+
+class convlab.agent.algorithm.dqn.DQNBase(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.dqn.VanillaDQN

+

Implementation of the base DQN algorithm. +The algorithm follows the same general approach as VanillaDQN but is more general since it allows +for two different networks (through self.net and self.target_net).

+

self.net is used to act, and is the network trained. +self.target_net is used to estimate the maximum value of the Q-function in the next state when calculating the target (see VanillaDQN comments). +self.target_net is updated periodically to either match self.net (self.net.update_type = “replace”) or to be a weighted average of self.net and the previous self.target_net (self.net.update_type = “polyak”) +If desired, self.target_net can be updated slowly, and this can help to stabilize learning.

+
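As a minimal sketch of the two target-net update modes described above (assuming torch modules; the function name is hypothetical, and the convention for which side polyak_coef weights may differ in the actual code):

    import torch.nn as nn

    def update_target_net_sketch(net: nn.Module, target_net: nn.Module,
                                 update_type='replace', polyak_coef=0.9):
        if update_type == 'replace':
            # hard update: target <- net
            target_net.load_state_dict(net.state_dict())
        else:
            # 'polyak': keep polyak_coef of the old target, blend in the rest from net
            for tar_p, src_p in zip(target_net.parameters(), net.parameters()):
                tar_p.data.copy_(polyak_coef * tar_p.data + (1 - polyak_coef) * src_p.data)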

It also allows for different nets to be used to select the action in the next state and to evaluate the value of that action through self.online_net and self.eval_net. This can help reduce the tendency of DQN’s to overestimate the value of the Q-function. Following this approach leads to the DoubleDQN algorithm.

+

Setting all nets to self.net reduces to the VanillaDQN case.

+
+
+calc_q_loss(batch)[source]
+

Compute the Q value loss using predicted and target Q values from the appropriate networks

+
+ +
+
+init_nets(global_nets=None)[source]
+

Initialize networks

+
+ +
+
+update()[source]
+

Updates self.target_net and the explore variables

+
+ +
+
+update_nets()[source]
+
+ +
+ +
+
+class convlab.agent.algorithm.dqn.DoubleDQN(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.dqn.DQN

+

Double-DQN (DDQN) class

+

e.g. algorithm_spec +“algorithm”: {

+
+

“name”: “DDQN”, +“action_pdtype”: “Argmax”, +“action_policy”: “epsilon_greedy”, +“explore_var_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 1.0, +“end_val”: 0.1, +“start_step”: 10, +“end_step”: 1000,
+

}, +“gamma”: 0.99, +“training_batch_iter”: 8, +“training_iter”: 4, +“training_frequency”: 10, +“training_start_step”: 10

+
+

}

+
+
+init_nets(global_nets=None)[source]
+

Initialize networks

+
+ +
+ +
+
+class convlab.agent.algorithm.dqn.VanillaDQN(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.sarsa.SARSA

+

Implementation of a simple DQN algorithm. +Algorithm:

+
+
  1. Collect some examples by acting in the environment and store them in a replay memory
  2. Every K steps sample N examples from replay memory
  3. For each example calculate the target (bootstrapped estimate of the discounted value of the state and action taken), y, using a neural network to approximate the Q function; s’ is the next state following the action actually taken (see the sketch below):
     y_t = r_t + gamma * max_a Q(s_t’, a)
  4. For each example calculate the current estimate of the discounted value of the state and action taken:
     x_t = Q(s_t, a_t)
  5. Calculate L(x, y) where L is a regression loss (e.g. MSE)
  6. Calculate the gradient of L with respect to all the parameters in the network and update the network parameters using the gradient
  7. Repeat steps 3 - 6 M times
  8. Repeat steps 2 - 7 Z times
  9. Repeat steps 1 - 8
+
+

For more information on Q-Learning see Sergey Levine’s lectures 6 and 7 from CS294-112 Fall 2017 +https://www.youtube.com/playlist?list=PLkFD6_40KJIznC9CDbVTjAF2oyt8_VAe3

+
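A minimal torch sketch of the target computation in step 3 above (hypothetical helper; assumes batch tensors named ‘rewards’, ‘next_states’, ‘dones’ with dones as 0/1 floats):

    import torch

    def calc_q_targets_sketch(batch, target_q_net, gamma=0.99):
        with torch.no_grad():
            next_q = target_q_net(batch['next_states'])  # (N, num_actions)
            max_next_q, _ = next_q.max(dim=-1)           # max_a Q(s', a)
            # zero the bootstrap term on terminal steps
            return batch['rewards'] + gamma * (1 - batch['dones']) * max_next_q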

e.g. algorithm_spec +“algorithm”: {

+
+

“name”: “VanillaDQN”, +“action_pdtype”: “Argmax”, +“action_policy”: “epsilon_greedy”, +“explore_var_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 1.0, +“end_val”: 0.1, +“start_step”: 10, +“end_step”: 1000,
+

}, +“gamma”: 0.99, +“training_batch_iter”: 8, +“training_iter”: 4, +“training_frequency”: 10, +“training_start_step”: 10,

+
+

}

+
+
+act(state)[source]
+

Selects and returns a discrete action for body using the action policy

+
+ +
+
+calc_q_loss(batch)[source]
+

Compute the Q value loss using predicted and target Q values from the appropriate networks

+
+ +
+
+init_algorithm_params()[source]
+

Initialize other algorithm parameters.

+
+ +
+
+init_nets(global_nets=None)[source]
+

Initialize the neural network used to learn the Q function from the spec

+
+ +
+
+sample()[source]
+

Samples a batch from memory of size self.memory_spec[‘batch_size’]

+
+ +
+
+train()[source]
+

Completes one training step for the agent if it is time to train, i.e. the environment timestep is greater than the minimum training timestep and a multiple of the training_frequency. Each training step consists of sampling n batches from the agent’s memory. For each of the batches, the target Q values (q_targets) are computed and a single training step is taken k times. Otherwise this function does nothing.

+
+ +
+
+update()[source]
+

Update the agent after training

+
+ +
+ +
+
+class convlab.agent.algorithm.dqn.WarmUpDQN(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.dqn.DQN

+

DQN class

+

e.g. algorithm_spec +“algorithm”: {

+
+

“name”: “WarmUpDQN”, +“action_pdtype”: “Argmax”, +“action_policy”: “epsilon_greedy”, +“warmup_epi”: 300, +“explore_var_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 1.0, +“end_val”: 0.1, +“start_step”: 10, +“end_step”: 1000,
+

}, +“gamma”: 0.99, +“training_batch_iter”: 8, +“training_iter”: 4, +“training_frequency”: 10, +“training_start_step”: 10

+
+

}

+
+
+init_nets(global_nets=None)[source]
+

Initialize networks

+
+ +
+
+train()[source]
+

Completes one training step for the agent if it is time to train, i.e. the environment timestep is greater than the minimum training timestep and a multiple of the training_frequency. Each training step consists of sampling n batches from the agent’s memory. For each of the batches, the target Q values (q_targets) are computed and a single training step is taken k times. Otherwise this function does nothing.

+
+ +
+
+warmup_sample()[source]
+

Samples a batch from warm-up memory

+
+ +
+ +
+
+

convlab.agent.algorithm.external module

+

The random agent algorithm +For basic dev purpose.

+
+
+class convlab.agent.algorithm.external.ExternalPolicy(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.base.Algorithm

+

Example Random agent that works in both discrete and continuous envs

+
+
+act(state)[source]
+

Standard act method.

+
+ +
+
+init_algorithm_params()[source]
+

Initialize other algorithm parameters

+
+ +
+
+init_nets(global_nets=None)[source]
+

Initialize the neural network from the spec

+
+ +
+
+reset()[source]
+
+ +
+
+sample()[source]
+

Samples a batch from memory

+
+ +
+
+train()[source]
+

Implement algorithm train, or throw NotImplementedError

+
+ +
+
+update()[source]
+

Implement algorithm update, or throw NotImplementedError

+
+ +
+ +
+
+

convlab.agent.algorithm.policy_util module

+
+
+class convlab.agent.algorithm.policy_util.VarScheduler(var_decay_spec=None)[source]
+

Bases: object

+

Variable scheduler for decaying variables such as explore_var (epsilon, tau) and entropy

+

e.g. spec +“explore_var_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 1.0, +“end_val”: 0.1, +“start_step”: 0, +“end_step”: 800,
+

},

+
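The linear_decay schedule implied by this spec can be sketched as (hypothetical helper):

    def linear_decay_sketch(start_val, end_val, start_step, end_step, step):
        # hold start_val before start_step, end_val after end_step,
        # and interpolate linearly in between
        if step < start_step:
            return start_val
        if step >= end_step:
            return end_val
        slope = (end_val - start_val) / (end_step - start_step)
        return start_val + slope * (step - start_step)

For example, with the spec above (1.0 to 0.1 over steps 0 to 800), step 400 yields 0.55.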
+
+update(algorithm, clock)[source]
+

Get an updated value for var

+
+ +
+ +
+
+convlab.agent.algorithm.policy_util.boltzmann(state, algorithm, body)[source]
+

Boltzmann policy: adjust pdparam with temperature tau; the higher the more randomness/noise in action.

+
+ +
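Roughly, for discrete actions (a sketch assuming torch; tau is the decayed explore variable):

    import torch

    def boltzmann_sketch(pdparam, tau):
        # dividing logits by tau flattens the distribution as tau grows,
        # producing more random actions
        action_pd = torch.distributions.Categorical(logits=pdparam / tau)
        return action_pd.sample()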
+
+convlab.agent.algorithm.policy_util.calc_pdparam(state, algorithm, body)[source]
+

Prepare the state and run algorithm.calc_pdparam to get pdparam for action_pd
@param tensor:state For pdparam = net(state)
@param algorithm The algorithm containing self.net
@param body Body which links algorithm to the env which the action is for
@returns tensor:pdparam
@example

+

pdparam = calc_pdparam(state, algorithm, body)
action_pd = ActionPD(logits=pdparam)  # e.g. ActionPD is Categorical
action = action_pd.sample()

+
+ +
+
+convlab.agent.algorithm.policy_util.default(state, algorithm, body)[source]
+

Plain policy by direct sampling from a default action probability defined by body.ActionPD

+
+ +
+
+convlab.agent.algorithm.policy_util.epsilon_greedy(state, algorithm, body)[source]
+

Epsilon-greedy policy: with probability epsilon, do random action, otherwise do default sampling.

+
+ +
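A minimal sketch of the control flow (hypothetical helper; assumes a gym-style action_space and that default sampling is supplied as a callable):

    import random

    def epsilon_greedy_sketch(epsilon, action_space, default_fn):
        # with probability epsilon take a random action,
        # otherwise defer to the default sampling policy
        if random.random() < epsilon:
            return action_space.sample()
        return default_fn()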
+
+convlab.agent.algorithm.policy_util.get_action_pd_cls(action_pdtype, action_type)[source]
+

Verify and get the action prob. distribution class for construction +Called by body at init to set its own ActionPD

+
+ +
+
+convlab.agent.algorithm.policy_util.get_action_type(action_space)[source]
+

Method to get the action type to choose prob. dist. to sample actions from NN logits output

+
+ +
+
+convlab.agent.algorithm.policy_util.guard_tensor(state, body)[source]
+

Guard-cast tensor before being input to network

+
+ +
+
+convlab.agent.algorithm.policy_util.init_action_pd(ActionPD, pdparam)[source]
+

Initialize the action_pd for discrete or continuous actions:
- discrete: action_pd = ActionPD(logits)
- continuous: action_pd = ActionPD(loc, scale)

+
+ +
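A sketch of the two construction modes (assuming torch distributions; the continuous pdparam is taken to be the (mean, std) pair produced by the network's two heads):

    import torch

    def init_action_pd_sketch(pdparam, is_discrete):
        if is_discrete:
            return torch.distributions.Categorical(logits=pdparam)
        loc, scale = pdparam  # mean and std dev heads (assumed layout)
        return torch.distributions.Normal(loc=loc, scale=scale)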
+
+convlab.agent.algorithm.policy_util.multi_boltzmann(states, algorithm, body_list, pdparam)[source]
+

Apply Boltzmann policy body-wise

+
+ +
+
+convlab.agent.algorithm.policy_util.multi_default(states, algorithm, body_list, pdparam)[source]
+

Apply default policy body-wise +Note, for efficiency, do a single forward pass to calculate pdparam, then call this policy like: +@example

+

pdparam = self.calc_pdparam(state)
action_a = self.action_policy(pdparam, self, body_list)

+
+ +
+
+convlab.agent.algorithm.policy_util.multi_epsilon_greedy(states, algorithm, body_list, pdparam)[source]
+

Apply epsilon-greedy policy body-wise

+
+ +
+
+convlab.agent.algorithm.policy_util.multi_random(states, algorithm, body_list, pdparam)[source]
+

Apply random policy body-wise.

+
+ +
+
+convlab.agent.algorithm.policy_util.random(state, algorithm, body)[source]
+

Random action using gym.action_space.sample(), with the same format as default()

+
+ +
+
+convlab.agent.algorithm.policy_util.rule_guide(state, algorithm, body)[source]
+
+ +
+
+convlab.agent.algorithm.policy_util.sample_action(ActionPD, pdparam)[source]
+

Convenience method to sample action(s) from action_pd = ActionPD(pdparam) +Works with batched pdparam too +@returns tensor:action Sampled action(s) +@example

+

# policy contains:
pdparam = calc_pdparam(state, algorithm, body)
action = sample_action(body.ActionPD, pdparam)

+
+ +
+
+convlab.agent.algorithm.policy_util.warmup_default(state, algorithm, body)[source]
+
+ +
+
+convlab.agent.algorithm.policy_util.warmup_epsilon_greedy(state, algorithm, body)[source]
+
+ +
+
+

convlab.agent.algorithm.ppo module

+
+
+class convlab.agent.algorithm.ppo.PPO(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.actor_critic.ActorCritic

+

Implementation of PPO +This is actually just ActorCritic with a custom loss function +Original paper: “Proximal Policy Optimization Algorithms” +https://arxiv.org/pdf/1707.06347.pdf

+

Adapted from OpenAI baselines, CPU version https://github.com/openai/baselines/tree/master/baselines/ppo1
Algorithm:
for iteration = 1, 2, 3, … do
    for actor = 1, 2, 3, …, N do
        run policy pi_old in env for T timesteps
        compute advantage A_1, …, A_T
    end for
    optimize surrogate L wrt theta, with K epochs and minibatch size M <= NT
end for

+

e.g. algorithm_spec +“algorithm”: {

+
+

“name”: “PPO”, +“action_pdtype”: “default”, +“action_policy”: “default”, +“explore_var_spec”: null, +“gamma”: 0.99, +“lam”: 1.0, +“clip_eps_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 0.01, +“end_val”: 0.001, +“start_step”: 100, +“end_step”: 5000,
+

}, +“entropy_coef_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 0.01, +“end_val”: 0.001, +“start_step”: 100, +“end_step”: 5000,
+

}, +“minibatch_size”: 256, +“training_frequency”: 1, +“training_epoch”: 8,

+
+

}

+

e.g. special net_spec param “shared” to share/separate Actor/Critic +“net”: {

+
+
“type”: “MLPNet”, +“shared”: true, +…
+
+
+calc_policy_loss(batch, pdparams, advs)[source]
+

The PPO loss function (subscript t is omitted) +L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]

+

Breakdown piecewise:
1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ], where ratio = pi(a|s) / pi_old(a|s) (sketched below)
2. L^VF = E[ mse(V(s_t), V^target) ]
3. S = E[ entropy ]
+
+ +
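A minimal torch sketch of the L^CLIP term (hypothetical helper; assumes 1D tensors of log-probabilities and advantages):

    import torch

    def ppo_clip_loss_sketch(log_probs, old_log_probs, advs, clip_eps):
        ratio = torch.exp(log_probs - old_log_probs)  # pi(a|s) / pi_old(a|s)
        clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
        # negated because the objective is maximized but optimizers minimize
        return -torch.min(ratio * advs, clipped_ratio * advs).mean()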
+
+init_algorithm_params()[source]
+

Initialize other algorithm parameters

+
+ +
+
+init_nets(global_nets=None)[source]
+

PPO uses old and new to calculate ratio for loss

+
+ +
+
+train()[source]
+

Train actor critic by computing the loss in batch efficiently

+
+ +
+
+update()[source]
+

Implement algorithm update, or throw NotImplementedError

+
+ +
+ +
+
+

convlab.agent.algorithm.random module

+
+
+class convlab.agent.algorithm.random.Random(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.base.Algorithm

+

Example Random agent that works in both discrete and continuous envs

+
+
+act(state)[source]
+

Random action

+
+ +
+
+init_algorithm_params()[source]
+

Initialize other algorithm parameters

+
+ +
+
+init_nets(global_nets=None)[source]
+

Initialize the neural network from the spec

+
+ +
+
+sample()[source]
+

Samples a batch from memory

+
+ +
+
+train()[source]
+

Implement algorithm train, or throw NotImplementedError

+
+ +
+
+update()[source]
+

Implement algorithm update, or throw NotImplementedError

+
+ +
+ +
+
+

convlab.agent.algorithm.reinforce module

+
+
+class convlab.agent.algorithm.reinforce.Reinforce(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.base.Algorithm

+

Implementation of REINFORCE (Williams, 1992) with baseline for discrete or continuous actions http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf +Adapted from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py +Algorithm:

+
+
  1. Collect n episodes of data
  2. At each timestep in an episode:
     • Calculate the advantage of that timestep
     • Multiply the advantage by the negative of the log probability of the action taken
  3. Sum all the values above.
  4. Calculate the gradient of this value with respect to all of the parameters of the network
  5. Update the network parameters using the gradient
+
+

e.g. algorithm_spec: +“algorithm”: {

+
+

“name”: “Reinforce”, +“action_pdtype”: “default”, +“action_policy”: “default”, +“explore_var_spec”: null, +“gamma”: 0.99, +“entropy_coef_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 0.01, +“end_val”: 0.001, +“start_step”: 100, +“end_step”: 5000,
+

}, +“training_frequency”: 1,

+
+

}

+
+
+act(state)[source]
+

Standard act method.

+
+ +
+
+calc_pdparam(x, net=None)[source]
+

The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.

+
+ +
+
+calc_pdparam_batch(batch)[source]
+

Efficiently forward to get pdparam by batch for loss computation

+
+ +
+
+calc_policy_loss(batch, pdparams, advs)[source]
+

Calculate the actor’s policy loss

+
+ +
+
+calc_ret_advs(batch)[source]
+

Calculate plain returns; which is generalized to advantage in ActorCritic

+
+ +
+
+init_algorithm_params()[source]
+

Initialize other algorithm parameters

+
+ +
+
+init_nets(global_nets=None)[source]
+

Initialize the neural network used to learn the policy function from the spec +Below we automatically select an appropriate net for a discrete or continuous action space if the setting is of the form ‘MLPNet’. Otherwise the correct type of network is assumed to be specified in the spec. +Networks for continuous action spaces have two heads and return two values, the first is a tensor containing the mean of the action policy, the second is a tensor containing the std deviation of the action policy. The distribution is assumed to be a Gaussian (Normal) distribution. +Networks for discrete action spaces have a single head and return the logits for a categorical probability distribution over the discrete actions

+
+ +
+
+sample()[source]
+

Samples a batch from memory

+
+ +
+
+train()[source]
+

Implement algorithm train, or throw NotImplementedError

+
+ +
+
+update()[source]
+

Implement algorithm update, or throw NotImplementedError

+
+ +
+ +
+
+class convlab.agent.algorithm.reinforce.WarmUpReinforce(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.reinforce.Reinforce

+

Implementation of REINFORCE (Williams, 1992) with baseline for discrete or continuous actions http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf +Adapted from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py +Algorithm:

+
+
  1. Collect n episodes of data
  2. At each timestep in an episode:
     • Calculate the advantage of that timestep
     • Multiply the advantage by the negative of the log probability of the action taken
  3. Sum all the values above.
  4. Calculate the gradient of this value with respect to all of the parameters of the network
  5. Update the network parameters using the gradient
+
+

e.g. algorithm_spec: +“algorithm”: {

+
+

“name”: “Reinforce”, +“action_pdtype”: “default”, +“action_policy”: “default”, +“warmup_epi”: 300, +“explore_var_spec”: null, +“gamma”: 0.99, +“entropy_coef_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 0.01, +“end_val”: 0.001, +“start_step”: 100, +“end_step”: 5000,
+

}, +“training_frequency”: 1,

+
+

}

+
+ +
+
+

convlab.agent.algorithm.sarsa module

+
+
+class convlab.agent.algorithm.sarsa.SARSA(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.base.Algorithm

+

Implementation of SARSA.

+

Algorithm: +Repeat:

+
+
  1. Collect some examples by acting in the environment and store them in an on-policy replay memory (either batch or episodic)
  2. For each example calculate the target (bootstrapped estimate of the discounted value of the state and action taken), y, using a neural network to approximate the Q function; s_t’ is the next state following the action actually taken, a_t, and a_t’ is the action actually taken in the next state s_t’ (see the sketch after this list):
     y_t = r_t + gamma * Q(s_t’, a_t’)
  3. For each example calculate the current estimate of the discounted value of the state and action taken:
     x_t = Q(s_t, a_t)
  4. Calculate L(x, y) where L is a regression loss (e.g. MSE)
  5. Calculate the gradient of L with respect to all the parameters in the network and update the network parameters using the gradient
+
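A torch sketch of the target in step 2 (hypothetical helper; assumes the batch stores the action actually taken in the next state as ‘next_actions’ and dones as 0/1 floats):

    import torch

    def sarsa_q_targets_sketch(batch, q_net, gamma=0.99):
        with torch.no_grad():
            next_q = q_net(batch['next_states'])                   # (N, num_actions)
            act_idx = batch['next_actions'].long().unsqueeze(-1)   # a' actually taken
            next_q_a = next_q.gather(-1, act_idx).squeeze(-1)      # Q(s', a')
            return batch['rewards'] + gamma * (1 - batch['dones']) * next_q_a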
+

e.g. algorithm_spec +“algorithm”: {

+
+

“name”: “SARSA”, +“action_pdtype”: “default”, +“action_policy”: “boltzmann”, +“explore_var_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 1.0, +“end_val”: 0.1, +“start_step”: 10, +“end_step”: 1000,
+

}, +“gamma”: 0.99, +“training_frequency”: 10,

+
+

}

+
+
+act(state)[source]
+

Note, SARSA is discrete-only

+
+ +
+
+calc_pdparam(x, net=None)[source]
+

To get the pdparam for action policy sampling, do a forward pass of the appropriate net, and pick the correct outputs. +The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.

+
+ +
+
+calc_q_loss(batch)[source]
+

Compute the Q value loss using predicted and target Q values from the appropriate networks

+
+ +
+
+init_algorithm_params()[source]
+

Initialize other algorithm parameters.

+
+ +
+
+init_nets(global_nets=None)[source]
+

Initialize the neural network used to learn the Q function from the spec

+
+ +
+
+sample()[source]
+

Samples a batch from memory

+
+ +
+
+train()[source]
+

Completes one training step for the agent if it is time to train. +Otherwise this function does nothing.

+
+ +
+
+update()[source]
+

Update the agent after training

+
+ +
+ +
+
+

convlab.agent.algorithm.sil module

+
+
+class convlab.agent.algorithm.sil.PPOSIL(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.sil.SIL, convlab.agent.algorithm.ppo.PPO

+

SIL extended from PPO. This will call the SIL methods and use PPO as super().

+

e.g. algorithm_spec +“algorithm”: {

+
+

“name”: “PPOSIL”, +“action_pdtype”: “default”, +“action_policy”: “default”, +“explore_var_spec”: null, +“gamma”: 0.99, +“lam”: 1.0, +“clip_eps_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 0.01, +“end_val”: 0.001, +“start_step”: 100, +“end_step”: 5000,
+

}, +“entropy_coef_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 0.01, +“end_val”: 0.001, +“start_step”: 100, +“end_step”: 5000,
+

}, +“sil_policy_loss_coef”: 1.0, +“sil_val_loss_coef”: 0.01, +“training_frequency”: 1, +“training_batch_iter”: 8, +“training_iter”: 8, +“training_epoch”: 8,

+
+

}

+

e.g. special memory_spec +“memory”: {

+
+
“name”: “OnPolicyReplay”, +“sil_replay_name”: “Replay”, +“batch_size”: 32, +“max_size”: 10000, +“use_cer”: true
+

}

+
+ +
+
+class convlab.agent.algorithm.sil.SIL(agent, global_nets=None)[source]
+

Bases: convlab.agent.algorithm.actor_critic.ActorCritic

+

Implementation of Self-Imitation Learning (SIL) https://arxiv.org/abs/1806.05635 +This is actually just A2C with an extra SIL loss function

+

e.g. algorithm_spec +“algorithm”: {

+
+

“name”: “SIL”, +“action_pdtype”: “default”, +“action_policy”: “default”, +“explore_var_spec”: null, +“gamma”: 0.99, +“lam”: 1.0, +“num_step_returns”: 100, +“entropy_coef_spec”: {

+
+
“name”: “linear_decay”, +“start_val”: 0.01, +“end_val”: 0.001, +“start_step”: 100, +“end_step”: 5000,
+

}, +“policy_loss_coef”: 1.0, +“val_loss_coef”: 0.01, +“sil_policy_loss_coef”: 1.0, +“sil_val_loss_coef”: 0.01, +“training_batch_iter”: 8, +“training_frequency”: 1, +“training_iter”: 8,

+
+

}

+

e.g. special memory_spec +“memory”: {

+
+
“name”: “OnPolicyReplay”, +“sil_replay_name”: “Replay”, +“batch_size”: 32, +“max_size”: 10000, +“use_cer”: true
+

}

+
+
+calc_sil_policy_val_loss(batch, pdparams)[source]
+

Calculate the SIL policy losses for actor and critic:
sil_policy_loss = -log_prob * max(R - v_pred, 0)
sil_val_loss = (max(R - v_pred, 0)^2) / 2
This is called on a randomly-sampled batch from experience replay.

+
+ +
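In torch terms, roughly (hypothetical helper; R is the return and v_pred the critic estimate):

    import torch

    def sil_losses_sketch(log_probs, rets, v_preds):
        clipped_adv = torch.clamp(rets - v_preds, min=0.0)  # max(R - v_pred, 0)
        sil_policy_loss = -(log_probs * clipped_adv.detach()).mean()
        sil_val_loss = (clipped_adv.pow(2) / 2).mean()
        return sil_policy_loss, sil_val_loss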
+
+init_algorithm_params()[source]
+

Initialize other algorithm parameters

+
+ +
+
+replay_sample()[source]
+

Samples a batch from memory

+
+ +
+
+sample()[source]
+

Modify the onpolicy sample to also append to replay

+
+ +
+
+train()[source]
+

Train actor critic by computing the loss in batch efficiently

+
+ +
+ +
+
+

Module contents

+

The algorithm module +Contains implementations of reinforcement learning algorithms. +Uses the nets module to build neural networks as the relevant function approximators

+
+
diff --git a/docs/build/html/convlab.agent.html b/docs/build/html/convlab.agent.html new file mode 100644 index 0000000..f877627 --- /dev/null +++ b/docs/build/html/convlab.agent.html @@ -0,0 +1,385 @@

convlab.agent package

+ +
+

Module contents

+
+
+class convlab.agent.Agent(spec, body, a=None, global_nets=None)[source]
+

Bases: object

+

Agent abstraction; implements the API to interface with Env in SLM Lab +Contains algorithm, memory, body

+
+
+act(state)[source]
+

Standard act method from algorithm.

+
+ +
+
+close()[source]
+

Close and cleanup agent at the end of a session, e.g. save model

+
+ +
+
+save(ckpt=None)[source]
+

Save agent

+
+ +
+
+update(state, action, reward, next_state, done)[source]
+

Update per timestep after env transitions, e.g. memory, algorithm, update agent params, train net

+
+ +
+ +
+
+class convlab.agent.Body(env, agent_spec, aeb=(0, 0, 0))[source]
+

Bases: object

+

Body of an agent inside an environment; it:
- enables the automatic dimension inference for constructing network input/output
- acts as a reference bridge between agent and environment (useful for multi-agent, multi-env)
- acts as non-gradient variable storage for monitoring and analysis

+
+
+calc_df_row(env)[source]
+

Calculate a row for updating train_df or eval_df.

+
+ +
+
+eval_ckpt(eval_env, avg_return, avg_len, avg_success)[source]
+

Checkpoint to update body.eval_df data

+
+ +
+
+get_log_prefix()[source]
+

Get the prefix for logging

+
+ +
+
+get_mean_lr()[source]
+

Gets the average current learning rate of the algorithm’s nets.

+
+ +
+
+log_metrics(metrics, df_mode)[source]
+

Log session metrics

+
+ +
+
+log_summary(df_mode)[source]
+

Log the summary for this body when its environment is done +@param str:df_mode ‘train’ or ‘eval’

+
+ +
+
+train_ckpt()[source]
+

Checkpoint to update body.train_df data

+
+ +
+
+update(state, action, reward, next_state, done)[source]
+

Interface update method for body at agent.update()

+
+ +
+ +
+
+class convlab.agent.DialogAgent(spec, body, a=None, global_nets=None)[source]
+

Bases: convlab.agent.Agent

+

Class for all Agents. +Standardizes the Agent design to work in Lab. +Access Envs properties by: Agents - AgentSpace - AEBSpace - EnvSpace - Envs

+
+
+act(obs)[source]
+

Standard act method from algorithm.

+
+ +
+
+action_decode(action, state)[source]
+
+ +
+
+close()[source]
+

Close and cleanup agent at the end of a session, e.g. save model

+
+ +
+
+get_env()[source]
+
+ +
+
+reset(obs)[source]
+

Do agent reset per session, such as memory pointer

+
+ +
+
+save(ckpt=None)[source]
+

Save agent

+
+ +
+
+state_update(obs, action)[source]
+
+ +
+
+update(obs, action, reward, next_obs, done)[source]
+

Update per timestep after env transitions, e.g. memory, algorithm, update agent params, train net

+
+ +
+ +
+
diff --git a/docs/build/html/convlab.agent.memory.html b/docs/build/html/convlab.agent.memory.html new file mode 100644 index 0000000..4b2ac0d --- /dev/null +++ b/docs/build/html/convlab.agent.memory.html @@ -0,0 +1,535 @@
+

convlab.agent.memory package

+
+

Submodules

+
+
+

convlab.agent.memory.base module

+
+
+class convlab.agent.memory.base.Memory(memory_spec, body)[source]
+

Bases: abc.ABC

+

Abstract Memory class to define the API methods

+
+
+reset()[source]
+

Method to fully reset the memory storage and related variables

+
+ +
+
+sample()[source]
+

Implement memory sampling mechanism

+
+ +
+
+update(state, action, reward, next_state, done)[source]
+

Implement memory update given the full info from the latest timestep. NOTE: guard for np.nan reward and done when individual env resets.

+
+ +
+ +
+
+

convlab.agent.memory.onpolicy module

+
+
+class convlab.agent.memory.onpolicy.OnPolicyBatchReplay(memory_spec, body)[source]
+

Bases: convlab.agent.memory.onpolicy.OnPolicyReplay

+

Same as OnPolicyReplay Memory with the following difference.

+

The memory does not have a fixed size. Instead the memory stores data from N experiences, where N is determined by the user. After N experiences or if an episode has ended, all of the examples are returned to the agent to learn from.

+

In contrast, OnPolicyReplay stores entire episodes and stores them in a nested structure. OnPolicyBatchReplay stores experiences in a flat structure.

+

e.g. memory_spec +“memory”: {

+
+
“name”: “OnPolicyBatchReplay”
+

} +* batch_size is training_frequency provided by algorithm_spec

+
+
+add_experience(state, action, reward, next_state, done)[source]
+

Interface helper method for update() to add experience to memory

+
+ +
+
+sample()[source]
+

Returns all the examples from memory in a single batch. Batch is stored as a dict. +Keys are the names of the different elements of an experience. Values are a list of the corresponding sampled elements +e.g. +batch = {

+
+
‘states’ : states, +‘actions’ : actions, +‘rewards’ : rewards, +‘next_states’: next_states, +‘dones’ : dones}
+
+ +
+ +
+
+class convlab.agent.memory.onpolicy.OnPolicyReplay(memory_spec, body)[source]
+

Bases: convlab.agent.memory.base.Memory

+

Stores agent experiences and returns them in a batch for agent training.

+
+
An experience consists of
+
  • state: representation of a state
  • action: action taken
  • reward: scalar value
  • next state: representation of next state (should be same as state)
  • done: 0 / 1 representing if the current state is the last in an episode
+
+
+

The memory does not have a fixed size. Instead the memory stores data from N episodes, where N is determined by the user. After N episodes, all of the examples are returned to the agent to learn from.

+

When the examples are returned to the agent, the memory is cleared to prevent the agent from learning from off policy experiences. This memory is intended for on policy algorithms.

+
+
Differences vs. Replay memory:
+
  • Experiences are nested into episodes. In Replay experiences are flat, and episode is not tracked
  • The entire memory constitutes a batch. In Replay batches are sampled from memory.
  • The memory is cleared automatically when a batch is given to the agent.
+
+
+

e.g. memory_spec +“memory”: {

+
+
“name”: “OnPolicyReplay”
+

}

+
+
+add_experience(state, action, reward, next_state, done)[source]
+

Interface helper method for update() to add experience to memory

+
+ +
+
+get_most_recent_experience()[source]
+

Returns the most recent experience

+
+ +
+
+reset()[source]
+

Resets the memory. Also used to initialize memory vars

+
+ +
+
+sample()[source]
+

Returns all the examples from memory in a single batch. Batch is stored as a dict. +Keys are the names of the different elements of an experience. Values are nested lists of the corresponding sampled elements. Elements are nested into episodes +e.g. +batch = {

+
+
‘states’ : [[s_epi1], [s_epi2], …], +‘actions’ : [[a_epi1], [a_epi2], …], +‘rewards’ : [[r_epi1], [r_epi2], …], +‘next_states’: [[ns_epi1], [ns_epi2], …], +‘dones’ : [[d_epi1], [d_epi2], …]}
+
+ +
+
+update(state, action, reward, next_state, done)[source]
+

Interface method to update memory

+
+ +
+ +
+
+

convlab.agent.memory.prioritized module

+
+
+class convlab.agent.memory.prioritized.PrioritizedReplay(memory_spec, body)[source]
+

Bases: convlab.agent.memory.replay.Replay

+

Prioritized Experience Replay

+

Implementation follows the approach in the paper “Prioritized Experience Replay”, Schaul et al 2015 https://arxiv.org/pdf/1511.05952.pdf and is Jaromír Janisch’s with minor adaptations.
See memory_util.py for the license and link to Jaromír’s excellent blog.

+

Stores agent experiences and samples from them for agent training according to each experience’s priority

+

The memory has the same behaviour and storage structure as Replay memory with the addition of a SumTree to store and sample the priorities.

+

e.g. memory_spec +“memory”: {

+
+
“name”: “PrioritizedReplay”, +“alpha”: 1, +“epsilon”: 0, +“batch_size”: 32, +“max_size”: 10000, +“use_cer”: true
+

}

+
+
+add_experience(state, action, reward, next_state, done, error=100000)[source]
+

Implementation for update() to add experience to memory, expanding the memory size if necessary. +All experiences are added with a high priority to increase the likelihood that they are sampled at least once.

+
+ +
+
+get_priority(error)[source]
+

Takes in the error of one or more examples and returns the proportional priority

+
+ +
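With the alpha and epsilon from the memory_spec above, the proportional priority is roughly (a sketch; the exact handling in the library may differ):

    def get_priority_sketch(error, epsilon=0.0, alpha=1.0):
        # proportional prioritization: p = (|error| + epsilon)^alpha
        return (abs(error) + epsilon) ** alpha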
+
+reset()[source]
+

Initializes the memory arrays, size and head pointer

+
+ +
+
+sample_idxs(batch_size)[source]
+

Samples batch_size indices from memory in proportional to their priority.

+
+ +
+
+update_priorities(errors)[source]
+

Updates the priorities from the most recent batch +Assumes the relevant batch indices are stored in self.batch_idxs

+
+ +
+ +
+
+class convlab.agent.memory.prioritized.SumTree(capacity)[source]
+

Bases: object

+

Helper class for PrioritizedReplay

+

This implementation is, with minor adaptations, Jaromír Janisch’s. The license is reproduced below. +For more information see his excellent blog series “Let’s make a DQN” https://jaromiru.com/2016/09/27/lets-make-a-dqn-theory/

+

MIT License

+

Copyright (c) 2018 Jaromír Janisch

+

Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the “Software”), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions:

+
+
+add(p, index)[source]
+
+ +
+
+get(s)[source]
+
+ +
+
+print_tree()[source]
+
+ +
+
+total()[source]
+
+ +
+
+update(idx, p)[source]
+
+ +
+
+write = 0
+
+ +
+ +
+
+

convlab.agent.memory.replay module

+
+
+class convlab.agent.memory.replay.Replay(memory_spec, body)[source]
+

Bases: convlab.agent.memory.base.Memory

+

Stores agent experiences and samples from them for agent training

+
+
An experience consists of
+
  • state: representation of a state
  • action: action taken
  • reward: scalar value
  • next state: representation of next state (should be same as state)
  • done: 0 / 1 representing if the current state is the last in an episode
+
+
+

The memory has a size of N. When capacity is reached, the oldest experience is deleted to make space for the latest experience.

+
+
  • This is implemented as a circular buffer so that inserting an experience is O(1)
  • Each element of an experience is stored as a separate array of size N * element dim
+
+

When a batch of experiences is requested, K experiences are sampled according to a random uniform distribution.

+

If ‘use_cer’, sampling will add the latest experience.

+
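A minimal numpy sketch of the circular-buffer insert described above (hypothetical class, reduced to a single states array):

    import numpy as np

    class RingBufferSketch:
        def __init__(self, max_size, state_dim):
            self.states = np.zeros((max_size, state_dim))
            self.max_size, self.head, self.size = max_size, -1, 0

        def add(self, state):
            self.head = (self.head + 1) % self.max_size  # wrap around: O(1) insert
            self.states[self.head] = state               # overwrites the oldest slot
            self.size = min(self.size + 1, self.max_size)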

e.g. memory_spec +“memory”: {

+
+
“name”: “Replay”, +“batch_size”: 32, +“max_size”: 10000, +“use_cer”: true
+

}

+
+
+add_experience(state, action, reward, next_state, done)[source]
+

Implementation for update() to add experience to memory, expanding the memory size if necessary

+
+ +
+
+reset()[source]
+

Initializes the memory arrays, size and head pointer

+
+ +
+
+sample()[source]
+

Returns a batch of batch_size samples. Batch is stored as a dict. +Keys are the names of the different elements of an experience. Values are an array of the corresponding sampled elements +e.g. +batch = {

+
+
‘states’ : states, +‘actions’ : actions, +‘rewards’ : rewards, +‘next_states’: next_states, +‘dones’ : dones}
+
+ +
+
+sample_idxs(batch_size)[source]
+

Batch indices are sampled uniformly at random

+
+ +
+
+update(state, action, reward, next_state, done)[source]
+

Interface method to update memory

+
+ +
+ +
+
+convlab.agent.memory.replay.sample_next_states(head, max_size, ns_idx_offset, batch_idxs, states, ns_buffer)[source]
+

Method to sample next_states from states, with proper guard for next_state idx being out of bound

+
+ +
+
+

Module contents

+

The memory module +Contains different ways of storing an agents experiences and sampling from them

+
+
diff --git a/docs/build/html/convlab.agent.net.html b/docs/build/html/convlab.agent.net.html new file mode 100644 index 0000000..19a5c04 --- /dev/null +++ b/docs/build/html/convlab.agent.net.html @@ -0,0 +1,747 @@
+

convlab.agent.net package

+
+

Submodules

+
+
+

convlab.agent.net.base module

+
+
+class convlab.agent.net.base.Net(net_spec, in_dim, out_dim)[source]
+

Bases: abc.ABC

+

Abstract Net class to define the API methods

+
+
+store_grad_norms()[source]
+

Stores the gradient norms for debugging.

+
+ +
+
+train_step(loss, optim, lr_scheduler, clock=None, global_net=None)[source]
+
+ +
+ +
+
+

convlab.agent.net.conv module

+
+
+class convlab.agent.net.conv.ConvNet(net_spec, in_dim, out_dim)[source]
+

Bases: convlab.agent.net.base.Net, torch.nn.modules.module.Module

+

Class for generating arbitrary sized convolutional neural network, +with optional batch normalization

+

Assumes that a single input example is organized into a 3D tensor. +The entire model consists of three parts:

+
+
  1. self.conv_model
  2. self.fc_model
  3. self.model_tails
+
+

e.g. net_spec +“net”: {

+
+

“type”: “ConvNet”, +“shared”: true, +“conv_hid_layers”: [

+
+
[32, 8, 4, 0, 1], +[64, 4, 2, 0, 1], +[64, 3, 1, 0, 1]
+

], +“fc_hid_layers”: [512], +“hid_layers_activation”: “relu”, +“out_layer_activation”: “tanh”, +“init_fn”: null, +“normalize”: false, +“batch_norm”: false, +“clip_grad_val”: 1.0, +“loss_spec”: {

+
+
“name”: “SmoothL1Loss”
+

}, +“optim_spec”: {

+
+
“name”: “Adam”, +“lr”: 0.02
+

}, +“lr_scheduler_spec”: {

+
+
“name”: “StepLR”, +“step_size”: 30, +“gamma”: 0.1
+

}, +“update_type”: “replace”, +“update_frequency”: 10000, +“polyak_coef”: 0.9, +“gpu”: true

+
+

}

+
+
+build_conv_layers(conv_hid_layers)[source]
+

Builds all of the convolutional layers in the network and store in a Sequential model

+
+ +
+
+forward(x)[source]
+

The feedforward step +Note that PyTorch takes (c,h,w) but gym provides (h,w,c), so preprocessing must be done before passing to network

+
+ +
+
+get_conv_output_size()[source]
+

Helper function to calculate the size of the flattened features after the final convolutional layer

+
+ +
+ +
+
+class convlab.agent.net.conv.DuelingConvNet(net_spec, in_dim, out_dim)[source]
+

Bases: convlab.agent.net.conv.ConvNet

+

Class for generating arbitrary sized convolutional neural network, +with optional batch normalization, and with dueling heads. Intended for Q-Learning algorithms only. +Implementation based on “Dueling Network Architectures for Deep Reinforcement Learning” http://proceedings.mlr.press/v48/wangf16.pdf

+

Assumes that a single input example is organized into a 3D tensor. +The entire model consists of three parts:

+
+
  1. self.conv_model
  2. self.fc_model
  3. self.model_tails
+
+

e.g. net_spec +“net”: {

+
+

“type”: “DuelingConvNet”, +“shared”: true, +“conv_hid_layers”: [

+
+
[32, 8, 4, 0, 1], +[64, 4, 2, 0, 1], +[64, 3, 1, 0, 1]
+

], +“fc_hid_layers”: [512], +“hid_layers_activation”: “relu”, +“init_fn”: “xavier_uniform_”, +“normalize”: false, +“batch_norm”: false, +“clip_grad_val”: 1.0, +“loss_spec”: {

+
+
“name”: “SmoothL1Loss”
+

}, +“optim_spec”: {

+
+
“name”: “Adam”, +“lr”: 0.02
+

}, +“lr_scheduler_spec”: {

+
+
“name”: “StepLR”, +“step_size”: 30, +“gamma”: 0.1
+

}, +“update_type”: “replace”, +“update_frequency”: 10000, +“polyak_coef”: 0.9, +“gpu”: true

+
+

}

+
+
+forward(x)[source]
+

The feedforward step

+
+ +
+ +
+
+

convlab.agent.net.mlp module

+
+
+class convlab.agent.net.mlp.DuelingMLPNet(net_spec, in_dim, out_dim)[source]
+

Bases: convlab.agent.net.mlp.MLPNet

+

Class for generating arbitrary sized feedforward neural network, with dueling heads. Intended for Q-Learning algorithms only. +Implementation based on “Dueling Network Architectures for Deep Reinforcement Learning” http://proceedings.mlr.press/v48/wangf16.pdf

+

e.g. net_spec +“net”: {

+
+

“type”: “DuelingMLPNet”, +“shared”: true, +“hid_layers”: [32], +“hid_layers_activation”: “relu”, +“init_fn”: “xavier_uniform_”, +“clip_grad_val”: 1.0, +“loss_spec”: {

+
+
“name”: “MSELoss”
+

}, +“optim_spec”: {

+
+
“name”: “Adam”, +“lr”: 0.02
+

}, +“lr_scheduler_spec”: {

+
+
“name”: “StepLR”, +“step_size”: 30, +“gamma”: 0.1
+

}, +“update_type”: “replace”, +“update_frequency”: 1, +“polyak_coef”: 0.9, +“gpu”: true

+
+

}

+
+
+forward(x)[source]
+

The feedforward step

+
+ +
+ +
+
+class convlab.agent.net.mlp.HydraMLPNet(net_spec, in_dim, out_dim)[source]
+

Bases: convlab.agent.net.base.Net, torch.nn.modules.module.Module

+

Class for generating arbitrary sized feedforward neural network with multiple state and action heads, and a single shared body.

+

e.g. net_spec +“net”: {

+
+

“type”: “HydraMLPNet”, +“shared”: true, +“hid_layers”: [

+
+
[[32],[32]], # 2 heads with hidden layers +[64], # body +[] # tail, no hidden layers
+

], +“hid_layers_activation”: “relu”, +“out_layer_activation”: null, +“init_fn”: “xavier_uniform_”, +“clip_grad_val”: 1.0, +“loss_spec”: {

+
+
“name”: “MSELoss”
+

}, +“optim_spec”: {

+
+
“name”: “Adam”, +“lr”: 0.02
+

}, +“lr_scheduler_spec”: {

+
+
“name”: “StepLR”, +“step_size”: 30, +“gamma”: 0.1
+

}, +“update_type”: “replace”, +“update_frequency”: 1, +“polyak_coef”: 0.9, +“gpu”: true

+
+

}

+
+
+build_model_heads(in_dim)[source]
+

Build each model_head. These are stored as Sequential models in model_heads

+
+ +
+
+build_model_tails(out_dim, out_layer_activation)[source]
+

Build each model_tail. These are stored as Sequential models in model_tails

+
+ +
+
+forward(xs)[source]
+

The feedforward step

+
+ +
+ +
+
+class convlab.agent.net.mlp.MLPNet(net_spec, in_dim, out_dim)[source]
+

Bases: convlab.agent.net.base.Net, torch.nn.modules.module.Module

+

Class for generating an arbitrarily-sized feedforward neural network.
If there is more than one output tensor, a self.model_tails is created instead of making the last layer part of self.model.

+

e.g. net_spec +“net”: {

+
+

“type”: “MLPNet”, +“shared”: true, +“hid_layers”: [32], +“hid_layers_activation”: “relu”, +“out_layer_activation”: null, +“init_fn”: “xavier_uniform_”, +“clip_grad_val”: 1.0, +“loss_spec”: {

+
+
“name”: “MSELoss”
+

}, +“optim_spec”: {

+
+
“name”: “Adam”, +“lr”: 0.02
+

}, +“lr_scheduler_spec”: {

+
+
“name”: “StepLR”, +“step_size”: 30, +“gamma”: 0.1
+

}, +“update_type”: “replace”, +“update_frequency”: 1, +“polyak_coef”: 0.9, +“gpu”: true

+
+

}

+
+
+forward(x)[source]
+

The feedforward step

+
+ +
+ +
+
+

convlab.agent.net.net_util module

+
+
+class convlab.agent.net.net_util.NoOpLRScheduler(optim)[source]
+

Bases: object

+

Symbolic LRScheduler class for API consistency

+
+
+get_lr()[source]
+
+ +
+
+step(epoch=None)[source]
+
+ +
+ +
+
+convlab.agent.net.net_util.build_fc_model(dims, activation=None)[source]
+

Build a fully-connected model by interleaving nn.Linear and activation_fn

+
+ +
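A sketch of the interleaving (assuming activation is passed as an nn module class such as nn.ReLU; the helper name is hypothetical):

    import torch.nn as nn

    def build_fc_model_sketch(dims, activation=None):
        # dims=[4, 64, 2] -> Linear(4,64), activation, Linear(64,2), activation
        layers = []
        for in_d, out_d in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_d, out_d))
            if activation is not None:
                layers.append(activation())
        return nn.Sequential(*layers)

    # e.g. build_fc_model_sketch([4, 64, 2], nn.ReLU)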
+
+convlab.agent.net.net_util.copy(src_net, tar_net)[source]
+

Copy model weights from src to target

+
+ +
+
+convlab.agent.net.net_util.dev_check_train_step(fn)[source]
+

Decorator to check if net.train_step actually updates the network weights properly +Triggers only if to_check_train_step is True (dev/test mode) +@example

+

@net_util.dev_check_train_step
def train_step(self, …):

+
+
+
+ +
+
+convlab.agent.net.net_util.get_activation_fn(activation)[source]
+

Helper to generate activation function layers for net

+
+ +
+
+convlab.agent.net.net_util.get_grad_norms(algorithm)[source]
+

Gather all the net’s grad norms of an algorithm for debugging

+
+ +
+
+convlab.agent.net.net_util.get_loss_fn(cls, loss_spec)[source]
+

Helper to parse loss param and construct loss_fn for net

+
+ +
+
+convlab.agent.net.net_util.get_lr_scheduler(optim, lr_scheduler_spec)[source]
+

Helper to parse lr_scheduler param and construct Pytorch optim.lr_scheduler

+
+ +
+
+convlab.agent.net.net_util.get_nn_name(uncased_name)[source]
+

Helper to get the proper name in PyTorch nn given a case-insensitive name

+
+ +
+
+convlab.agent.net.net_util.get_optim(net, optim_spec)[source]
+

Helper to parse optim param and construct optim for net

+
+ +
+
+convlab.agent.net.net_util.get_out_dim(body, add_critic=False)[source]
+

Construct the NetClass out_dim for a body according to is_discrete, action_type, and whether to add a critic unit

+
+ +
+
+convlab.agent.net.net_util.get_policy_out_dim(body)[source]
+

Helper method to construct the policy network out_dim for a body according to is_discrete, action_type

+
+ +
+
+convlab.agent.net.net_util.init_global_nets(algorithm)[source]
+

Initialize global_nets for Hogwild using an identical instance of an algorithm from an isolated Session.
In spec.meta.distributed, specify either:
- ‘shared’: the global network parameters are shared all the time. In this mode, the algorithm’s local network is replaced directly by global_net via the identical attribute name.
- ‘synced’: the global network parameters are periodically synced to the local network after each gradient push. In this mode, the algorithm keeps a separate reference to global_{net} for each of its networks.

+
+ +
+
+convlab.agent.net.net_util.init_layers(net, init_fn_name)[source]
+

Primary method to initialize the weights of the layers of a network

+
+ +
+
+convlab.agent.net.net_util.init_params(module, init_fn)[source]
+

Initialize module’s weights using init_fn, and biases to 0.0

+
+ +
+
+convlab.agent.net.net_util.load(net, model_path)[source]
+

Load model weights from a path into a net module

+
+ +
+
+convlab.agent.net.net_util.load_algorithm(algorithm)[source]
+

Load all the nets for an algorithm

+
+ +
+
+convlab.agent.net.net_util.polyak_update(src_net, tar_net, old_ratio=0.5)[source]
+

Polyak weight update to update a target tar_net, retain old weights by its ratio, i.e. +target <- old_ratio * source + (1 - old_ratio) * target

+
+ +
+
+convlab.agent.net.net_util.push_global_grads(net, global_net)[source]
+

Push gradients to global_net, call inside train_step between loss.backward() and optim.step()

+
+ +
+
+convlab.agent.net.net_util.save(net, model_path)[source]
+

Save model weights to path

+
+ +
+
+convlab.agent.net.net_util.save_algorithm(algorithm, ckpt=None)[source]
+

Save all the nets for an algorithm

+
+ +
+
+convlab.agent.net.net_util.set_global_nets(algorithm, global_nets)[source]
+

For Hogwild, set attr built in init_global_nets above. Use in algorithm init.

+
+ +
+
+convlab.agent.net.net_util.to_check_train_step()[source]
+

Condition for running assert_trained

+
+ +
+
+

convlab.agent.net.recurrent module

+
+
+class convlab.agent.net.recurrent.RecurrentNet(net_spec, in_dim, out_dim)[source]
+

Bases: convlab.agent.net.base.Net, torch.nn.modules.module.Module

+

Class for generating arbitrary sized recurrent neural networks which take a sequence of states as input.

+

Assumes that a single input example is organized into a 3D tensor +batch_size x seq_len x state_dim +The entire model consists of three parts:

+
+
  1. self.fc_model (state processing)
  2. self.rnn_model
  3. self.model_tails
+
+

e.g. net_spec +“net”: {

+
+

“type”: “RecurrentNet”, +“shared”: true, +“cell_type”: “GRU”, +“fc_hid_layers”: [], +“hid_layers_activation”: “relu”, +“out_layer_activation”: null, +“rnn_hidden_size”: 32, +“rnn_num_layers”: 1, +“bidirectional”: False, +“seq_len”: 4, +“init_fn”: “xavier_uniform_”, +“clip_grad_val”: 1.0, +“loss_spec”: {

+
+
“name”: “MSELoss”
+

}, +“optim_spec”: {

+
+
“name”: “Adam”, +“lr”: 0.01
+

}, +“lr_scheduler_spec”: {

+
+
“name”: “StepLR”, +“step_size”: 30, +“gamma”: 0.1
+

}, +“update_type”: “replace”, +“update_frequency”: 1, +“polyak_coef”: 0.9, +“gpu”: true

+
+

}

+
+
+forward(x)[source]
+

The feedforward step. Input is batch_size x seq_len x state_dim

+
+ +
+ +
+
+

Module contents

+

The nets module +Contains classes of neural network architectures

+
+
diff --git a/docs/build/html/convlab.env.html b/docs/build/html/convlab.env.html new file mode 100644 index 0000000..14112a9 --- /dev/null +++ b/docs/build/html/convlab.env.html @@ -0,0 +1,789 @@
+

convlab.env package

+
+

Submodules

+
+
+

convlab.env.base module

+
+
+class convlab.env.base.BaseEnv(spec, e=None)[source]
+

Bases: abc.ABC

+

The base Env class with API and helper methods. Use this to implement your env class that is compatible with the Lab APIs

+

e.g. env_spec +“env”: [{

+
+
“name”: “PongNoFrameskip-v4”, +“frame_op”: “concat”, +“frame_op_len”: 4, +“normalize_state”: false, +“reward_scale”: “sign”, +“num_envs”: 8, +“max_t”: null, +“max_frame”: 1e7
+

}],

+
+
+close()[source]
+

Method to close and cleanup env

+
+ +
+
+reset()[source]
+

Reset method, return state

+
+ +
+
+step(action)[source]
+

Step method, return state, reward, done, info

+
+ +
+ +
+
+class convlab.env.base.Clock(max_frame=10000000, clock_speed=1)[source]
+

Bases: object

+

Clock class for each env and space to keep track of relative time. Ticking and control loop is such that reset is at t=0 and epi=0
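A hypothetical usage sketch based on the methods listed below; the unit names 't' and 'epi' are taken from this docstring, and 'frame' from the get() default:

from convlab.env.base import Clock

clock = Clock(max_frame=10000000, clock_speed=1)
clock.reset()               # back to t=0, epi=0
clock.tick('t')             # advance one time step
clock.tick('epi')           # advance one episode (unit name assumed from the docstring)
frame = clock.get('frame')  # read the current frame count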

+
+
+get(unit='frame')[source]
+
+ +
+
+get_elapsed_wall_t()[source]
+

Calculate the elapsed wall time (int seconds) since self.start_wall_t

+
+ +
+
+reset()[source]
+
+ +
+
+set_batch_size(batch_size)[source]
+
+ +
+
+tick(unit='t')[source]
+
+ +
+ +
+
+convlab.env.base.set_gym_space_attr(gym_space)[source]
+

Set missing gym space attributes for standardization

+
+ +
+
+

convlab.env.movie module

+
+
+class convlab.env.movie.KBHelper(movie_dictionary)[source]
+

Bases: object

+

An assistant to fill in values for the agent (which knows about slots of values)

+
+
+available_results_from_kb(current_slots)[source]
+

Return the available movies in the movie_kb based on the current constraints

+
+ +
+
+available_results_from_kb_for_slots(inform_slots)[source]
+

Return the count statistics for each constraint in inform_slots

+
+ +
+
+available_slot_values(slot, kb_results)[source]
+

Return the set of values available for the slot based on the current constraints

+
+ +
+
+database_results_for_agent(current_slots)[source]
+

A dictionary of the number of results matching each current constraint. The agent needs this to decide what to do next.

+
+ +
+
+fill_inform_slots(inform_slots_to_be_filled, current_slots)[source]
+

Takes unfilled inform slots and current_slots, returns dictionary of filled informed slots (with values)

+

Arguments:
inform_slots_to_be_filled – Something that looks like {starttime:None, theater:None} where starttime and theater are slots that the agent needs filled
current_slots – Contains a record of all filled slots in the conversation so far; for now, just use current_slots['inform_slots'], which is a dictionary of the already filled-in slots

+

Returns:
filled_in_slots – A dictionary of form {slot1:value1, slot2:value2} for each sloti in inform_slots_to_be_filled

+
+ +
+
+suggest_slot_values(request_slots, current_slots)[source]
+

Return the suggested slot values

+
+ +
+ +
+
+class convlab.env.movie.MovieActInActOutEnvironment(worker_id=None)[source]
+

Bases: object

+
+
+action_decode(action)[source]
+

DQN: Input state, output action

+
+ +
+
+action_index(act_slot_response)[source]
+

Return the index of action

+
+ +
+
+close()[source]
+
+ +
+
+initialize_episode()[source]
+

Initialize a new episode. This function is called every time a new episode is run.

+
+ +
+
+prepare_state_representation(state)[source]
+

Create the representation for each state

+
+ +
+
+print_function(agent_action=None, user_action=None)[source]
+

Print Function

+
+ +
+
+reset(train_mode, config)[source]
+
+ +
+
+reward_function(dialog_status)[source]
+

Reward Function 1: a reward function based on the dialog_status

+
+ +
+
+reward_function_without_penalty(dialog_status)[source]
+

Reward Function 2: a reward function without per-turn and failed-dialog penalties

+
+ +
+
+rule_policy()[source]
+

Rule Policy

+
+ +
+
+step(action)[source]
+
+ +
+ +
+
+class convlab.env.movie.MovieEnv(spec, e=None, env_space=None)[source]
+

Bases: convlab.env.base.BaseEnv

+

Wrapper for Unity ML-Agents env to work with the Lab.

+

e.g. env_spec
"env": [{
    "name": "gridworld",
    "max_t": 20,
    "max_tick": 3,
    "unity": {
        "gridSize": 6,
        "numObstacles": 2,
        "numGoals": 1
    }
}],

+
+
+close()[source]
+

Method to close and cleanup env

+
+ +
+
+patch_gym_spaces(u_env)[source]
+

For standardization, use gym spaces to represent observation and action spaces.
This method iterates through the multiple brains (multiagent) then constructs and returns lists of observation_spaces and action_spaces

+
+ +
+
+reset()[source]
+

Reset method, return state

+
+ +
+
+space_init(env_space)[source]
+

Post init override for space env. Note that aeb is already correct from __init__

+
+ +
+
+space_reset()[source]
+
+ +
+
+space_step(action_e)[source]
+
+ +
+
+step(action)[source]
+

Step method, return state, reward, done, info

+
+ +
+ +
+
+class convlab.env.movie.RuleSimulator(movie_dict=None, act_set=None, slot_set=None, start_set=None, params=None)[source]
+

Bases: convlab.env.movie.UserSimulator

+

A rule-based user simulator for testing dialog policy

+
+
+corrupt(user_action)[source]
+

Randomly corrupt an action with error probs (slot_err_probability and slot_err_mode) on Slot and Intent (intent_err_probability).

+
+ +
+
+debug_falk_goal()[source]
+

Debug function: build a fake goal manually (can be moved in the future)

+
+ +
+
+initialize_episode()[source]
+

Initialize a new episode (dialog)
state['history_slots']: keeps all the informed slots
state['rest_slots']: keeps all the slots that are still in the stack

+
+ +
+
+next(system_action)[source]
+

Generate next User Action based on last System Action

+
+ +
+
+response_confirm_answer(system_action)[source]
+

Response for Confirm_Answer (System Action)

+
+ +
+
+response_inform(system_action)[source]
+

Response for Inform (System Action)

+
+ +
+
+response_multiple_choice(system_action)[source]
+

Response for Multiple_Choice (System Action)

+
+ +
+
+response_request(system_action)[source]
+

Response for Request (System Action)

+
+ +
+
+response_thanks(system_action)[source]
+

Response for Thanks (System Action)

+
+ +
+ +
+
+class convlab.env.movie.State(state=None, reward=None, done=None)[source]
+

Bases: object

+
+ +
+
+class convlab.env.movie.StateTracker(act_set, slot_set, movie_dictionary)[source]
+

Bases: object

+

The state tracker maintains a record of which request slots are filled and which inform slots are filled

+
+
+dialog_history_dictionaries()[source]
+

Return the dictionary representation of the dialog history (includes values)

+
+ +
+
+dialog_history_vectors()[source]
+

Return the dialog history (both user and agent actions) in vector representation

+
+ +
+
+get_current_kb_results()[source]
+

get the kb_results for current state

+
+ +
+
+get_state_for_agent()[source]
+

Get the state representations to send to the agent

+
+ +
+
+get_suggest_slots_values(request_slots)[source]
+

Get the suggested values for request slots

+
+ +
+
+initialize_episode()[source]
+

Initialize a new episode (dialog), flush the current state and tracked slots

+
+ +
+
+kb_results_for_state()[source]
+

Return the information about the database results based on the currently informed slots

+
+ +
+
+update(agent_action=None, user_action=None)[source]
+

Update the state based on the latest action

+
+ +
+ +
+
+class convlab.env.movie.UserSimulator(movie_dict=None, act_set=None, slot_set=None, start_set=None, params=None)[source]
+

Bases: object

+

Parent class for all user sims to inherit from

+
+
+add_nl_to_action(user_action)[source]
+

Add NL to User Dia_Act

+
+ +
+
+initialize_episode()[source]
+

Initialize a new episode (dialog)

+
+ +
+
+next(system_action)[source]
+
+ +
+
+set_nlg_model(nlg_model)[source]
+
+ +
+
+set_nlu_model(nlu_model)[source]
+
+ +
+ +
+
+convlab.env.movie.text_to_dict(path)[source]
+

Read in a text file as a dictionary where keys are text and values are indices (line numbers)
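A minimal sketch of the described behavior (an assumption based on this docstring, not the actual source):

def text_to_dict_sketch(path):
    # keys are the stripped lines of text, values are their line numbers
    with open(path) as f:
        return {line.strip(): idx for idx, line in enumerate(f)}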

+
+ +
+
+

convlab.env.multiwoz module

+
+
+class convlab.env.multiwoz.MultiWozEnv(spec, e=None)[source]
+

Bases: convlab.env.base.BaseEnv

+

Wrapper for Unity ML-Agents env to work with the Lab.

+

e.g. env_spec
"env": [{
    "name": "gridworld",
    "max_t": 20,
    "max_tick": 3,
    "unity": {
        "gridSize": 6,
        "numObstacles": 2,
        "numGoals": 1
    }
}],

+
+
+close()[source]
+

Method to close and cleanup env

+
+ +
+
+get_goal()[source]
+
+ +
+
+get_last_act()[source]
+
+ +
+
+get_sys_act()[source]
+
+ +
+
+get_task_success()[source]
+
+ +
+
+patch_gym_spaces(u_env)[source]
+

For standardization, use gym spaces to represent observation and action spaces.
This method iterates through the multiple brains (multiagent) then constructs and returns lists of observation_spaces and action_spaces

+
+ +
+
+reset()[source]
+

Reset method, return state

+
+ +
+
+step(action)[source]
+

Step method, return state, reward, done, info

+
+ +
+ +
+
+class convlab.env.multiwoz.MultiWozEnvironment(env_spec, worker_id=None, action_dim=300)[source]
+

Bases: object

+
+
+close()[source]
+
+ +
+
+get_goal()[source]
+
+ +
+
+get_last_act()[source]
+
+ +
+
+get_sys_act()[source]
+
+ +
+
+reset(train_mode, config)[source]
+
+ +
+
+rule_policy(state, algorithm, body)[source]
+
+ +
+
+step(action)[source]
+
+ +
+ +
+
+class convlab.env.multiwoz.State(state=None, reward=None, done=None)[source]
+

Bases: object

+
+ +
+
+

Module contents

+

The environment module
Contains graduated components from experiments for building/using environments.
Provides the rich experience for agent embodiment, reflects the curriculum and allows teaching (possibly allows a teacher to enter).
To be designed by the human and evolution modules, based on the curriculum and fitness metrics.

+
+
+class convlab.env.EnvSpace(spec, aeb_space)[source]
+

Bases: object

+

Subspace of AEBSpace, collection of all envs, with interface to Session logic; same methods as singleton envs.
Access AgentSpace properties by: AgentSpace - AEBSpace - EnvSpace - Envs

+
+
+close()[source]
+
+ +
+
+get(e)[source]
+
+ +
+
+get_base_clock()[source]
+

Get the clock with the finest time unit, i.e. ticks the most cycles in a given time, or the highest clock_speed

+
+ +
+
+reset()[source]
+
+ +
+
+step(action_space)[source]
+
+ +
+ +
+
+convlab.env.make_env(spec, e=None)[source]
+
+ +
+
\ No newline at end of file
diff --git a/docs/build/html/convlab.evaluator.html b/docs/build/html/convlab.evaluator.html
new file mode 100644
index 0000000..b26c7f1
--- /dev/null
+++ b/docs/build/html/convlab.evaluator.html
@@ -0,0 +1,312 @@
+ convlab.evaluator package — ConvLab 0.1 documentation

convlab.evaluator package

+
+

Submodules

+
+
+

convlab.evaluator.evaluator module

+
+
+class convlab.evaluator.evaluator.Evaluator[source]
+

Bases: object

+
+
+add_goal(goal)[source]
+

init goal and array
:param goal: dict[domain] dict['info'/'book'/'reqt'] dict/dict/list[slot]

+
+ +
+
+add_sys_da(da_turn)[source]
+

add sys_da into array
:param da_turn: dict[domain-intent] list[slot, value]

+
+ +
+
+add_usr_da(da_turn)[source]
+

add usr_da into array
:param da_turn: dict[domain-intent] list[slot, value]

+
+ +
+
+book_rate(ref2goal=True, aggregate=True)[source]
+

judge if the selected entity meets the constraint

+
+ +
+
+domain_success(domain, ref2goal=True)[source]
+

judge if the domain (subtask) is successfully completed

+
+ +
+
+inform_F1(ref2goal=True, aggregate=True)[source]
+

judge if all the requested information is answered

+
+ +
+
+task_success(ref2goal=True)[source]
+

judge if all the domains are successfully completed

+
+ +
+ +
+
+

convlab.evaluator.multiwoz module

+
+
+class convlab.evaluator.multiwoz.MultiWozEvaluator[source]
+

Bases: convlab.evaluator.evaluator.Evaluator

+
+
+add_goal(goal)[source]
+

init goal and array
:param goal: dict[domain] dict['info'/'book'/'reqt'] dict/dict/list[slot]

+
+ +
+
+add_sys_da(da_turn)[source]
+

add sys_da into array
:param da_turn: dict[domain-intent] list[slot, value]

+
+ +
+
+add_usr_da(da_turn)[source]
+

add usr_da into array
:param da_turn: dict[domain-intent] list[slot, value]

+
+ +
+
+book_rate(ref2goal=True, aggregate=True)[source]
+

judge if the selected entity meets the constraint

+
+ +
+
+domain_success(domain, ref2goal=True)[source]
+

judge if the domain (subtask) is successfully completed

+
+ +
+
+inform_F1(ref2goal=True, aggregate=True)[source]
+

judge if all the requested information is answered

+
+ +
+
+task_success(ref2goal=True)[source]
+

judge if all the domains are successfully completed
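A hypothetical wiring of these calls in an evaluation loop; goal and dialog_turns (an iterable of per-turn user/system dialog acts) are made-up inputs for illustration:

from convlab.evaluator.multiwoz import MultiWozEvaluator

def evaluate_dialog_sketch(goal, dialog_turns):
    evaluator = MultiWozEvaluator()
    evaluator.add_goal(goal)               # goal: dict[domain] dict['info'/'book'/'reqt']
    for usr_da, sys_da in dialog_turns:    # assumed iterable of (usr_da, sys_da) pairs
        evaluator.add_usr_da(usr_da)
        evaluator.add_sys_da(sys_da)
    return evaluator.task_success(ref2goal=True)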

+
+ +
+ +
+
+

Module contents

+
+
\ No newline at end of file
diff --git a/docs/build/html/convlab.experiment.html b/docs/build/html/convlab.experiment.html
new file mode 100644
index 0000000..0f646e3
--- /dev/null
+++ b/docs/build/html/convlab.experiment.html
@@ -0,0 +1,311 @@
+ convlab.experiment package — ConvLab 0.1 documentation

convlab.experiment package

+
+

Submodules

+
+
+

convlab.experiment.analysis module

+
+
+convlab.experiment.analysis.analyze_experiment(spec, trial_data_dict)[source]
+

Analyze experiment and save data

+
+ +
+
+convlab.experiment.analysis.analyze_session(session_spec, session_df, df_mode)[source]
+

Analyze session and save data, then return metrics. Note there are 2 types of session_df: body.eval_df and body.train_df

+
+ +
+
+convlab.experiment.analysis.analyze_trial(trial_spec, session_metrics_list)[source]
+

Analyze trial and save data, then return metrics

+
+ +
+
+convlab.experiment.analysis.calc_experiment_df(trial_data_dict, info_prepath=None)[source]
+

Collect all trial data (metrics and config) from trials into a dataframe

+
+ +
+
+convlab.experiment.analysis.calc_session_metrics(session_df, env_name, info_prepath=None, df_mode=None)[source]
+

Calculate the session metrics: strength, efficiency, stability
@param DataFrame:session_df Dataframe containing reward, frame, opt_step
@param str:env_name Name of the environment to get its random baseline
@param str:info_prepath Optional info_prepath to auto-save the output to
@param str:df_mode Optional df_mode to save with info_prepath
@returns dict:metrics Consists of scalar metrics and series local metrics

+
+ +
+
+convlab.experiment.analysis.calc_trial_metrics(session_metrics_list, info_prepath=None)[source]
+

Calculate the trial metrics: mean(strength), mean(efficiency), mean(stability), consistency
@param list:session_metrics_list The metrics collected from each session; format: {session_index: {'scalar': {...}, 'local': {...}}}
@param str:info_prepath Optional info_prepath to auto-save the output to
@returns dict:metrics Consists of scalar metrics and series local metrics

+
+ +
+
+convlab.experiment.analysis.gen_avg_result(agent, env, num_eval=4)[source]
+
+ +
+
+convlab.experiment.analysis.gen_avg_return(agent, env, num_eval=4)[source]
+

Generate average return for agent and an env

+
+ +
+
+convlab.experiment.analysis.gen_result(agent, env)[source]
+

Generate average return for agent and an env

+
+ +
+
+convlab.experiment.analysis.gen_return(agent, env)[source]
+

Generate return for an agent and an env in eval mode

+
+ +
+
+

convlab.experiment.control module

+
+
+

convlab.experiment.retro_analysis module

+
+
+convlab.experiment.retro_analysis.retro_analyze(predir)[source]
+

Method to analyze experiment/trial from files after it ran.
@example

yarn retro_analyze data/reinforce_cartpole_2018_01_22_211751/

+
+ +
+
+convlab.experiment.retro_analysis.retro_analyze_experiment(predir)[source]
+

Retro analyze an experiment

+
+ +
+
+convlab.experiment.retro_analysis.retro_analyze_sessions(predir)[source]
+

Retro analyze all sessions

+
+ +
+
+convlab.experiment.retro_analysis.retro_analyze_trials(predir)[source]
+

Retro analyze all trials

+
+ +
+
+

convlab.experiment.search module

+
+
+

Module contents

+
+
\ No newline at end of file
diff --git a/docs/build/html/convlab.html b/docs/build/html/convlab.html
new file mode 100644
index 0000000..35556d8
--- /dev/null
+++ b/docs/build/html/convlab.html
@@ -0,0 +1,636 @@
+ convlab package — ConvLab 0.1 documentation

convlab package

+
+

Subpackages

+
+ +
+
+
+

Module contents

+
+
\ No newline at end of file
diff --git a/docs/build/html/convlab.human_eval.html b/docs/build/html/convlab.human_eval.html
new file mode 100644
index 0000000..d857a0c
--- /dev/null
+++ b/docs/build/html/convlab.human_eval.html
@@ -0,0 +1,264 @@
+ convlab.human_eval package — ConvLab 0.1 documentation

convlab.human_eval package

+
+

Submodules

+
+
+

convlab.human_eval.analysis module

+
+
+convlab.human_eval.analysis.main()[source]
+

This task consists of an MTurk agent evaluating a chit-chat model. They are asked to chat to the model adopting a specific persona. After their conversation, they are asked to evaluate their partner on several metrics.

+
+ +
+
+

convlab.human_eval.bot_server module

+
+
+

convlab.human_eval.cambot_server module

+
+
+

convlab.human_eval.dqnbot_server module

+
+
+

convlab.human_eval.rulebot_server module

+
+
+

convlab.human_eval.run module

+
+
+

convlab.human_eval.sequicity_server module

+
+
+convlab.human_eval.sequicity_server.generate_response(in_queue, out_queue)[source]
+
+ +
+
+convlab.human_eval.sequicity_server.process()[source]
+
+ +
+
+

convlab.human_eval.task_config module

+
+
+convlab.human_eval.task_config.task_config = {'hit_description': 'You will chat to a tour information bot and then evaluate that bot.', 'hit_keywords': 'chat,dialog', 'hit_title': 'Chat and evaluate bot!', 'task_description': '\n (You can keep accepting new HITs after you finish your current one, so keep working on it if you like the task!)\n <br>\n <b>In this task you will chat with an information desk clerk bot to plan your tour according to a given goal.</b>\n <br>\n For example, your given goal and expected conversation could be: <br><br> \n <table border="1" cellpadding="10">\n <tr><th>Your goal</th><th>Expected conversation</th></tr>\n <tr><td>\n <ul>\n <li>You are looking for a <b>place to stay</b>. The hotel should be in the <b>cheap</b> price range and should be in the type of <b>hotel</b></li>\n <li>The hotel should include <b>free parking</b> and should include <b>free wifi</b></li>\n <li>Once you find the hotel, you want to book it for <b>6</b> people and <b>3</b> nights</b> starting from <b>tuesday</b></li>\n <li>If the booking fails how about <b>2</b> nights</li>\n <li>Make sure you get the <b>reference number</b></li>\n </ul>\n </td>\n <td>\n <b>You: </b>I am looking for a place to to stay that has cheap price range it should be in a type of hotel<br>\n <b>Info desk: </b>Okay, do you have a specific area you want to stay in?<br>\n <b>You: </b>no, i just need to make sure it\'s cheap. oh, and i need parking<br>\n <b>Info desk: </b>I found 1 cheap hotel for you that includes parking. Do you like me to book it?<br>\n <b>You: </b>Yes, please. 6 people 3 nights starting on tuesday.<br>\n <b>Info desk: </b>I am sorry but I wasn\'t able to book that for you for Tuesday. Is there another day you would like to stay or perhaps a shorter stay?<br>\n <b>You: </b>how about only 2 nights.<br>\n <b>Info desk: </b>Booking was successful.\nReference number is : 7GAWK763. Anything else I can do for you?<br>\n <b>You: </b>No, that will be all. Good bye.<br>\n <b>Info desk: </b>Thank you for using our services.<br>\n </td>\n </table>\n <br><br>\n Chat with the bot naturally and stick to your own goal but <b>do not trivially copy the goal descriptions into the message.</b>\n <br>\n Once the conversation is done, you will be asked to rate the bot on metrics like <b>goal accomplishment, language understanding, and response naturalness</b>.\n <br>\n There is a <b>2 min</b> time limit for each turn.\n <br>\n <br>\n - Do not reference the task or MTurk itself during the conversation.\n <br>\n <b><span style="color:red">- No racism, sexism or otherwise offensive comments, or the submission will be rejected and we will report to Amazon.</b></span>\n <br>\n <br>\n '}
+

A short and descriptive title about the kind of task the HIT contains.
On the Amazon Mechanical Turk web site, the HIT title appears in search results, and everywhere the HIT is mentioned.

+
+ +
+
+

convlab.human_eval.worlds module

+
+
+

Module contents

+
+
\ No newline at end of file
diff --git a/docs/build/html/convlab.lib.html b/docs/build/html/convlab.lib.html
new file mode 100644
index 0000000..fde0da2
--- /dev/null
+++ b/docs/build/html/convlab.lib.html
@@ -0,0 +1,1117 @@
+ convlab.lib package — ConvLab 0.1 documentation

convlab.lib package

+
+

Submodules

+
+
+

convlab.lib.decorator module

+
+
+convlab.lib.decorator.lab_api(fn)[source]
+

Function decorator to label and check Lab API methods
@example

from convlab.lib.decorator import lab_api

@lab_api
def foo():
    print('foo')
+
+ +
+
+convlab.lib.decorator.timeit(fn)[source]
+

Function decorator to measure execution time
@example

import time
from convlab.lib.decorator import timeit

@timeit
def foo(sec):
    time.sleep(sec)
    print('foo')

foo(1)
# => foo
# => Timed: foo 1000.9971ms

+
+ +
+
+

convlab.lib.distribution module

+
+
+class convlab.lib.distribution.Argmax(probs=None, logits=None, validate_args=None)[source]
+

Bases: torch.distributions.categorical.Categorical

+

Special distribution class for argmax sampling, where probability is always 1 for the argmax.
NOTE although argmax is not a sampling distribution, this implementation is for API consistency.

+
+ +
+
+class convlab.lib.distribution.GumbelCategorical(probs=None, logits=None, validate_args=None)[source]
+

Bases: torch.distributions.categorical.Categorical

+

Special Categorical using the Gumbel distribution to simulate softmax categorical for discrete actions.
Similar to OpenAI's https://github.com/openai/baselines/blob/98257ef8c9bd23a24a330731ae54ed086d9ce4a7/baselines/a2c/utils.py#L8-L10
Explanation: http://amid.fish/assets/gumbel.html

+
+
+sample(sample_shape=torch.Size([]))[source]
+

Gumbel softmax sampling

+
+ +
+ +
+
+class convlab.lib.distribution.MultiCategorical(probs=None, logits=None, validate_args=None)[source]
+

Bases: torch.distributions.categorical.Categorical

+

MultiCategorical as collection of Categoricals

+
+
+entropy()[source]
+

Returns entropy of distribution, batched over batch_shape.

Returns: Tensor of shape batch_shape.
+
+ +
+
+enumerate_support()[source]
+

Returns a tensor containing all values supported by a discrete distribution. The result will enumerate over dimension 0, so the shape of the result will be (cardinality,) + batch_shape + event_shape (where event_shape = () for univariate distributions).

+

Note that this enumerates over all batched tensors in lock-step [[0, 0], [1, 1], ...]. With expand=False, enumeration happens along dim 0, but with the remaining batch dimensions being singleton dimensions, [[0], [1], ...

+

To iterate over the full Cartesian product use +itertools.product(m.enumerate_support()).

Parameters: expand (bool) – whether to expand the support over the batch dims to match the distribution's batch_shape.
Returns: Tensor iterating over dimension 0.
+
+ +
+
+log_prob(value)[source]
+

Returns the log of the probability density/mass function evaluated at value.

Parameters: value (Tensor)
+
+ +
+
+logits
+
+ +
+
+mean
+

Returns the mean of the distribution.

+
+ +
+
+param_shape
+
+ +
+
+probs
+
+ +
+
+sample(sample_shape=torch.Size([]))[source]
+

Generates a sample_shape shaped sample, or a sample_shape shaped batch of samples if the distribution parameters are batched.

+
+ +
+
+variance
+

Returns the variance of the distribution.

+
+ +
+ +
+
+

convlab.lib.file_util module

+
+
+convlab.lib.file_util.cached_path(file_path, cached_dir=None)[source]
+
+ +
+
+

convlab.lib.logger module

+
+
+class convlab.lib.logger.FixedList[source]
+

Bases: list

+

fixed-list to restrict addition to root logger handler

+
+
+append(object) → None -- append object to end[source]
+
+ +
+ +
+
+convlab.lib.logger.act(msg, *args, **kwargs)[source]
+
+ +
+
+convlab.lib.logger.critical(msg, *args, **kwargs)[source]
+
+ +
+
+convlab.lib.logger.debug(msg, *args, **kwargs)[source]
+
+ +
+
+convlab.lib.logger.error(msg, *args, **kwargs)[source]
+
+ +
+
+convlab.lib.logger.exception(msg, *args, **kwargs)[source]
+
+ +
+
+convlab.lib.logger.get_logger(__name__)[source]
+

Create a child logger specific to a module

+
+ +
+
+convlab.lib.logger.info(msg, *args, **kwargs)[source]
+
+ +
+
+convlab.lib.logger.nl(msg, *args, **kwargs)[source]
+
+ +
+
+convlab.lib.logger.set_level(lvl)[source]
+
+ +
+
+convlab.lib.logger.state(msg, *args, **kwargs)[source]
+
+ +
+
+convlab.lib.logger.toggle_debug(modules, level='DEBUG')[source]
+

Turn on module-specific debugging using their names, e.g. algorithm, actor_critic, at the desired debug level.

+
+ +
+
+convlab.lib.logger.warning(msg, *args, **kwargs)[source]
+
+ +
+
+

convlab.lib.math_util module

+
+
+convlab.lib.math_util.calc_gaes(rewards, dones, v_preds, gamma, lam)[source]
+

Calculate GAE from Schulman et al. https://arxiv.org/pdf/1506.02438.pdf
v_preds are values predicted for current states, with one last element as the final next_state
delta is defined as r + gamma * V(s') - V(s) in eqn 10
GAE is defined in eqn 16
This method computes in torch tensor to prevent unnecessary moves between devices (e.g. GPU tensor to CPU numpy)
NOTE any standardization is done outside of this method
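A sketch of the backward recursion these equations describe, assuming rewards and dones are length-T float tensors and v_preds has one extra bootstrap element (an illustration, not the library source):

import torch

def calc_gaes_sketch(rewards, dones, v_preds, gamma, lam):
    T = len(rewards)
    not_dones = 1 - dones            # zero out bootstrapping across episode ends
    gaes = torch.zeros_like(rewards)
    future_gae = 0.0
    for t in reversed(range(T)):
        # delta = r + gamma * V(s') - V(s), per eqn 10
        delta = rewards[t] + gamma * v_preds[t + 1] * not_dones[t] - v_preds[t]
        # accumulate per eqn 16
        gaes[t] = future_gae = delta + gamma * lam * not_dones[t] * future_gae
    return gaes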

+
+ +
+
+convlab.lib.math_util.calc_nstep_returns(rewards, dones, next_v_pred, gamma, n)[source]
+

Calculate the n-step returns for advantage. Ref: http://www-anw.cs.umass.edu/~barto/courses/cs687/Chapter%207.pdf
Also see Algorithm S3 from the A3C paper https://arxiv.org/pdf/1602.01783.pdf for the calculation used below
R^(n)_t = r_{t} + gamma r_{t+1} + ... + gamma^(n-1) r_{t+n-1} + gamma^(n) V(s_{t+n})
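A sketch of that formula as a backward recursion, with next_v_pred as the gamma^(n) V(s_{t+n}) bootstrap (assumed tensor shapes; not the library source):

import torch

def calc_nstep_returns_sketch(rewards, dones, next_v_pred, gamma, n):
    rets = torch.zeros_like(rewards[:n])
    future_ret = next_v_pred         # bootstrap with V(s_{t+n})
    not_dones = 1 - dones
    for t in reversed(range(n)):
        rets[t] = future_ret = rewards[t] + gamma * future_ret * not_dones[t]
    return rets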

+
+ +
+
+convlab.lib.math_util.calc_q_value_logits(state_value, raw_advantages)[source]
+
+ +
+
+convlab.lib.math_util.calc_returns(rewards, dones, gamma)[source]
+

Calculate the simple returns (full rollout) i.e. sum discounted rewards up till termination

+
+ +
+
+convlab.lib.math_util.linear_decay(start_val, end_val, start_step, end_step, step)[source]
+

Simple linear decay with annealing

+
+ +
+
+convlab.lib.math_util.no_decay(start_val, end_val, start_step, end_step, step)[source]
+

dummy method for API consistency

+
+ +
+
+convlab.lib.math_util.normalize(v)[source]
+

Method to normalize a rank-1 np array

+
+ +
+
+convlab.lib.math_util.periodic_decay(start_val, end_val, start_step, end_step, step, frequency=60.0)[source]
+

Linearly decaying sinusoid that decays in roughly 10 iterations until explore_anneal_epi
Plot the equation below to see the pattern
Suppose sinusoidal decay with start_val = 1, end_val = 0.2, stopping after 60 unscaled x steps;
then we get 0.2 + 0.5*(1 - 0.2)*(1 + cos x)*(1 - x/60)
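A direct transcription of the quoted equation, generalized from the start_val=1, end_val=0.2 example; x is assumed to be the already-rescaled step (an illustration, not the library source):

import numpy as np

def periodic_decay_sketch(start_val, end_val, x, stop=60):
    # end_val + 0.5*(start_val - end_val)*(1 + cos x)*(1 - x/stop)
    return end_val + 0.5 * (start_val - end_val) * (1 + np.cos(x)) * (1 - x / stop)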

+
+ +
+
+convlab.lib.math_util.rate_decay(start_val, end_val, start_step, end_step, step, decay_rate=0.9, frequency=20.0)[source]
+

Compounding rate decay that anneals in 20 decay iterations until end_step

+
+ +
+
+convlab.lib.math_util.standardize(v)[source]
+

Method to standardize a rank-1 np array

+
+ +
+
+convlab.lib.math_util.to_one_hot(data, max_val)[source]
+

Convert an int list of data into one-hot vectors

+
+ +
+
+convlab.lib.math_util.venv_pack(batch_tensor, num_envs)[source]
+

Apply the reverse of venv_unpack to pack a batch tensor from (b*num_envs, *shape) to (b, num_envs, *shape)

+
+ +
+
+convlab.lib.math_util.venv_unpack(batch_tensor)[source]
+

Unpack a sampled vec env batch tensor
e.g. for a state with original shape (4,), a vec env should return a vec state with shape (num_envs, 4) to store in memory
When sampled with batch_size b, we get shape (b, num_envs, 4). But we need to unpack the num_envs dimension to get (b * num_envs, 4) for passing to a network. This method does that.
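A shape-only illustration of the unpack/pack round trip described above (values assumed; illustration, not the library source):

import torch

batch = torch.zeros(8, 4, 10)   # (b, num_envs, state_dim) as stored in memory
flat = batch.view(-1, 10)       # venv_unpack-style: (b * num_envs, state_dim) for the network
packed = flat.view(-1, 4, 10)   # venv_pack-style: back to (b, num_envs, state_dim)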

+
+ +
+
+

convlab.lib.optimizer module

+
+
+class convlab.lib.optimizer.GlobalAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)[source]
+

Bases: torch.optim.adam.Adam

+

Global Adam algorithm with shared states for Hogwild.
Adapted from https://github.com/ikostrikov/pytorch-a3c/blob/master/my_optim.py (MIT)

+
+
+share_memory()[source]
+
+ +
+
+step(closure=None)[source]
+

Performs a single optimization step.

Parameters: closure (callable, optional) – A closure that reevaluates the model and returns the loss.
+
+ +
+ +
+
+class convlab.lib.optimizer.GlobalRMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0)[source]
+

Bases: torch.optim.rmsprop.RMSprop

+

Global RMSprop algorithm with shared states for Hogwild.
Adapted from https://github.com/jingweiz/pytorch-rl/blob/master/optims/sharedRMSprop.py (MIT)

+
+
+share_memory()[source]
+
+ +
+
+step(closure=None)[source]
+

Performs a single optimization step.

Parameters: closure (callable, optional) – A closure that reevaluates the model and returns the loss.
+
+ +
+ +
+
+

convlab.lib.util module

+
+
+class convlab.lib.util.LabJsonEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)[source]
+

Bases: json.encoder.JSONEncoder

+
+
+default(obj)[source]
+

Implement this method in a subclass such that it returns a serializable object for o, or calls the base implementation (to raise a TypeError).

+

For example, to support arbitrary iterators, you could implement default like this:

+
def default(self, o):
    try:
        iterable = iter(o)
    except TypeError:
        pass
    else:
        return list(iterable)
    # Let the base class default method raise the TypeError
    return JSONEncoder.default(self, o)
+
+
+
+ +
+ +
+
+convlab.lib.util.batch_get(arr, idxs)[source]
+

Get multi-idxs from an array, depending on whether it's a python list or np.array

+
+ +
+
+convlab.lib.util.calc_srs_mean_std(sr_list)[source]
+

Given a list of series, calculate their mean and std

+
+ +
+
+convlab.lib.util.calc_ts_diff(ts2, ts1)[source]
+

Calculate the time difference from timestamp ts1 to ts2
@param {str} ts2 Later ts in the FILE_TS_FORMAT
@param {str} ts1 Earlier ts in the FILE_TS_FORMAT
@returns {str} delta_t in %H:%M:%S format
@example

ts1 = '2017_10_17_084739'
ts2 = '2017_10_17_084740'
ts_diff = util.calc_ts_diff(ts2, ts1)
# => '0:00:01'

+
+ +
+
+convlab.lib.util.cast_df(val)[source]
+

missing pydash method to cast value as DataFrame

+
+ +
+
+convlab.lib.util.cast_list(val)[source]
+

missing pydash method to cast value as list

+
+ +
+
+convlab.lib.util.clear_periodic_ckpt(prepath)[source]
+

Clear periodic (with -epi) ckpt files in prepath

+
+ +
+
+convlab.lib.util.concat_batches(batches)[source]
+

Concat batch objects from body.memory.sample() into one batch, when all bodies experience similar envs
Also concat any nested epi sub-batches into a flat batch
{k: arr1} + {k: arr2} = {k: arr1 + arr2}

+
+ +
+
+convlab.lib.util.ctx_lab_mode(lab_mode)[source]
+

Creates a context to run a method with a specific lab_mode
@example

with util.ctx_lab_mode('eval'):
    foo()

@util.ctx_lab_mode('eval')
def foo():
    ...

+
+
+
+ +
+
+convlab.lib.util.debug_image(im)[source]
+

Renders an image for debugging; pauses process until key press
Handles tensor/numpy and conventions among libraries

+
+ +
+
+convlab.lib.util.downcast_float32(df)[source]
+

Downcast any float64 col to float32 to allow safer pandas comparison

+
+ +
+
+convlab.lib.util.epi_done(done)[source]
+

General method to check if episode is done for both single and vectorized env
Only return True for singleton done, since a vectorized env does not have a natural episode boundary

+
+ +
+
+convlab.lib.util.find_ckpt(prepath)[source]
+

Find the ckpt-lorem-ipsum in a string and return lorem-ipsum

+
+ +
+
+convlab.lib.util.flatten_dict(obj, delim='.')[source]
+

Missing pydash method to flatten dict

+
+ +
+
+convlab.lib.util.frame_mod(frame, frequency, num_envs)[source]
+

Generic mod for (frame % frequency == 0) for when num_envs is 1 or more;
since frame will increase multiple ticks for a vector env, use the remainder

+
+ +
+
+convlab.lib.util.get_class_attr(obj)[source]
+

Get the class attr of an object as dict

+
+ +
+
+convlab.lib.util.get_class_name(obj, lower=False)[source]
+

Get the class name of an object

+
+ +
+
+convlab.lib.util.get_file_ext(data_path)[source]
+

get the .ext of file.ext

+
+ +
+
+convlab.lib.util.get_fn_list(a_cls)[source]
+

Get the callable, non-private functions of a class
@returns {[*str]} A list of strings of fn names

+
+ +
+
+convlab.lib.util.get_git_sha()[source]
+
+ +
+
+convlab.lib.util.get_lab_mode()[source]
+
+ +
+
+convlab.lib.util.get_prepath(spec, unit='experiment')[source]
+
+ +
+
+convlab.lib.util.get_ts(pattern='%Y_%m_%d_%H%M%S')[source]
+

Get current ts, defaults to format used for filename
@param {str} pattern To format the ts
@returns {str} ts
@example

util.get_ts()
# => '2017_10_17_084739'

+
+ +
+
+convlab.lib.util.grayscale_image(im)[source]
+
+ +
+
+convlab.lib.util.in_eval_lab_modes()[source]
+

Check if lab_mode is one of EVAL_MODES

+
+ +
+
+convlab.lib.util.insert_folder(prepath, folder)[source]
+

Insert a folder into prepath

+
+ +
+
+convlab.lib.util.is_jupyter()[source]
+

Check if process is in Jupyter kernel

+
+ +
+
+convlab.lib.util.monkey_patch(base_cls, extend_cls)[source]
+

Monkey patch a base class with methods from extend_cls

+
+ +
+
+convlab.lib.util.mpl_debug_image(im)[source]
+

Uses matplotlib to plot image with bigger size, axes, and false color on greyscaled images

+
+ +
+
+convlab.lib.util.normalize_image(im)[source]
+

Normalize an image by dividing by the max value 255

+
+ +
+
+convlab.lib.util.parallelize(fn, args, num_cpus=24)[source]
+

Parallelize a method fn over args and return results with order preserved per args.
args should be a list of tuples.
@returns {list} results Order-preserved output from fn.

+
+ +
+
+convlab.lib.util.prepath_split(prepath)[source]
+

Split prepath into useful names. Works with predir (prename will be None)
prepath: output/dqn_pong_2018_12_02_082510/dqn_pong_t0_s0
predir: output/dqn_pong_2018_12_02_082510
prefolder: dqn_pong_2018_12_02_082510
prename: dqn_pong_t0_s0
spec_name: dqn_pong
experiment_ts: 2018_12_02_082510
ckpt: ckpt-best of dqn_pong_t0_s0_ckpt-best if available

+
+ +
+
+convlab.lib.util.prepath_to_idxs(prepath)[source]
+

Extract trial index and session index from prepath if available

+
+ +
+
+convlab.lib.util.prepath_to_spec(prepath)[source]
+

Given a prepath, read the correct spec and recover the meta_spec that will return the same prepath for eval lab modes
example: output/a2c_cartpole_2018_06_13_220436/a2c_cartpole_t0_s0

+
+ +
+
+convlab.lib.util.preprocess_image(im)[source]
+

Image preprocessing using the OpenAI Baselines method: grayscale, resize
This resize uses stretching instead of cropping

+
+ +
+
+convlab.lib.util.read(data_path, **kwargs)[source]
+

Universal data reading method with smart data parsing
- {.csv} to DataFrame
- {.json} to dict, list
- {.yml} to dict
- {*} to str
@param {str} data_path The data path to read from
@returns {data} The read data in sensible format
@example

data_df = util.read('test/fixture/lib/util/test_df.csv')
# => <DataFrame>

data_dict = util.read('test/fixture/lib/util/test_dict.json')
data_dict = util.read('test/fixture/lib/util/test_dict.yml')
# => <dict>

data_list = util.read('test/fixture/lib/util/test_list.json')
# => <list>

data_str = util.read('test/fixture/lib/util/test_str.txt')
# => <str>

+
+ +
+
+convlab.lib.util.read_as_df(data_path, **kwargs)[source]
+

Submethod to read data as DataFrame

+
+ +
+
+convlab.lib.util.read_as_pickle(data_path, **kwargs)[source]
+

Submethod to read data as pickle

+
+ +
+
+convlab.lib.util.read_as_plain(data_path, **kwargs)[source]
+

Submethod to read data as plain type

+
+ +
+
+convlab.lib.util.resize_image(im, w_h)[source]
+
+ +
+
+convlab.lib.util.run_cmd(cmd)[source]
+

Run shell command

+
+ +
+
+convlab.lib.util.run_cmd_wait(proc)[source]
+

Wait on a running process created by util.run_cmd and print its stdout

+
+ +
+
+convlab.lib.util.self_desc(cls)[source]
+

Method to get self description, used at init.

+
+ +
+
+convlab.lib.util.set_attr(obj, attr_dict, keys=None)[source]
+

Set attribute of an object from a dict

+
+ +
+
+convlab.lib.util.set_cuda_id(spec)[source]
+

Use trial and session id to hash and modulo cuda device count for a cuda_id to maximize device usage. Sets the net_spec for the base Net class to pick up.

+
+ +
+
+convlab.lib.util.set_logger(spec, logger, unit=None)[source]
+

Set the logger for a lab unit given its spec

+
+ +
+
+convlab.lib.util.set_random_seed(spec)[source]
+

Generate and set random seed for relevant modules, and record it in spec.meta.random_seed

+
+ +
+
+convlab.lib.util.sizeof(obj, divisor=1000000.0)[source]
+

Return the size of object, in MB by default

+
+ +
+
+convlab.lib.util.smart_path(data_path, as_dir=False)[source]
+

Resolve data_path into abspath with fallback to join from ROOT_DIR
@param {str} data_path The input data path to resolve
@param {bool} as_dir Whether to return as dirname
@returns {str} The normalized absolute data_path
@example

util.smart_path('convlab/lib')
# => '/Users/ANON/Documents/convlab/convlab/lib'

util.smart_path('/tmp')
# => '/tmp'

+
+ +
+
+convlab.lib.util.split_minibatch(batch, mb_size)[source]
+

Split a batch into minibatches of mb_size or smaller, without replacement

+
+ +
+
+convlab.lib.util.to_json(d, indent=2)[source]
+

Shorthand method for stringify JSON with indent

+
+ +
+
+convlab.lib.util.to_opencv_image(im)[source]
+

Convert to OpenCV image shape h,w,c

+
+ +
+
+convlab.lib.util.to_pytorch_image(im)[source]
+

Convert to PyTorch image shape c,h,w

+
+ +
+
+convlab.lib.util.to_render()[source]
+
+ +
+
+convlab.lib.util.to_torch_batch(batch, device, is_episodic)[source]
+

Mutate a batch (dict) to make its values from numpy into PyTorch tensor

+
+ +
+
+convlab.lib.util.write(data, data_path)[source]
+

Universal data writing method with smart data parsing
- {.csv} from DataFrame
- {.json} from dict, list
- {.yml} from dict
- {*} from str(*)
@param {*} data The data to write
@param {str} data_path The data path to write to
@returns {data_path} The data path written to
@example

data_path = util.write(data_df, 'test/fixture/lib/util/test_df.csv')

data_path = util.write(data_dict, 'test/fixture/lib/util/test_dict.json')
data_path = util.write(data_dict, 'test/fixture/lib/util/test_dict.yml')

data_path = util.write(data_list, 'test/fixture/lib/util/test_list.json')

data_path = util.write(data_str, 'test/fixture/lib/util/test_str.txt')

+
+ +
+
+convlab.lib.util.write_as_df(data, data_path)[source]
+

Submethod to write data as DataFrame

+
+ +
+
+convlab.lib.util.write_as_pickle(data, data_path)[source]
+

Submethod to write data as pickle

+
+ +
+
+convlab.lib.util.write_as_plain(data, data_path)[source]
+

Submethod to write data as plain type

+
+ +
+
+

convlab.lib.viz module

+
+
+convlab.lib.viz.create_label(y_col, x_col, title=None, y_title=None, x_title=None, legend_name=None)[source]
+

Create label dict for go.Layout with smart resolution

+
+ +
+
+convlab.lib.viz.create_layout(title, y_title, x_title, x_type=None, width=500, height=500, layout_kwargs=None)[source]
+

simplified method to generate Layout

+
+ +
+
+convlab.lib.viz.get_palette(size)[source]
+

Get the suitable palette of a certain size

+
+ +
+
+convlab.lib.viz.lower_opacity(rgb, opacity)[source]
+
+ +
+
+convlab.lib.viz.plot(*args, **kwargs)[source]
+
+ +
+
+convlab.lib.viz.plot_experiment(experiment_spec, experiment_df, metrics_cols)[source]
+

Plot the metrics vs. spec parameters of an experiment, where each point is a trial.
ref colors: https://plot.ly/python/heatmaps-contours-and-2dhistograms-tutorial/#plotlys-predefined-color-scales

+
+ +
+
+convlab.lib.viz.plot_mean_sr(sr_list, time_sr, title, y_title, x_title)[source]
+

Plot a list of series using its mean, with error bar using std

+
+ +
+
+convlab.lib.viz.plot_session(session_spec, session_metrics, session_df, df_mode='eval')[source]
+

Plot the session graphs:
- mean_returns, strengths, sample_efficiencies, training_efficiencies, stabilities (with error bar)
- additional plots from session_df: losses, exploration variable, entropy

+
+ +
+
+convlab.lib.viz.plot_sr(sr, time_sr, title, y_title, x_title)[source]
+

Plot a series

+
+ +
+
+convlab.lib.viz.plot_trial(trial_spec, trial_metrics)[source]
+

Plot the trial graphs:
- mean_returns, strengths, sample_efficiencies, training_efficiencies, stabilities (with error bar)
- consistencies (no error bar)

+
+ +
+
+convlab.lib.viz.save_image(figure, filepath)[source]
+
+ +
+
+

Module contents

+
+
\ No newline at end of file
diff --git a/docs/build/html/tasktk.dialog_agent.html b/docs/build/html/convlab.modules.action_decoder.html
similarity index 51%
rename from docs/build/html/tasktk.dialog_agent.html
rename to docs/build/html/convlab.modules.action_decoder.html
index 24d5fda..a44a796 100644
--- a/docs/build/html/tasktk.dialog_agent.html
+++ b/docs/build/html/convlab.modules.action_decoder.html
@@ -8,7 +8,7 @@
- tasktk.dialog_agent package — Tasktk documentation
+ convlab.modules.action_decoder package — ConvLab 0.1 documentation
@@ -24,6 +24,7 @@
@@ -40,15 +41,14 @@