This repository has been archived by the owner on May 5, 2023. It is now read-only.

Commit

Update
paulbricman committed Apr 8, 2021
1 parent 6c94322 commit 241f545
Showing 2 changed files with 22 additions and 22 deletions.
Binary file not shown.
44 changes: 22 additions & 22 deletions file-structure-replica/.obsidian/plugins/Dual/skeleton/core.py
@@ -15,11 +15,11 @@ def __init__(self, root_dir):
         self.root_dir = root_dir
         self.cache_address = os.path.join(root_dir, '.obsidian/plugins/Dual/skeleton/cache.pickle')
         self.entry_regex = os.path.join(root_dir, '*md')
-        self.auxiliary_models_ready = False
-        self.main_model_ready = False
+        self.skeleton_ready = False
+        self.essence_ready = False

-        self.load_auxiliary_models()
-        self.load_main_model()
+        self.load_skeleton()
+        self.load_essence()

         if os.path.isfile(self.cache_address) is False:
             self.create_cache()
@@ -28,10 +28,10 @@ def __init__(self, root_dir):
             self.sync_cache()

     def fluid_search(self, query, considered_candidates=50, selected_candidates=5, second_pass=True):
-        self.load_main_model()
+        self.load_essence()

-        if self.main_model_ready == False:
-            return ['The aligned model is unavailable at the required location.']
+        if self.essence_ready == False:
+            return ['The essence is not present at the required location.']

         self.sync_cache()
         selected_candidates = min(selected_candidates, considered_candidates)
@@ -50,10 +50,10 @@ def fluid_search(self, query, considered_candidates=50, selected_candidates=5, second_pass=True):
         return [self.entry_filenames[hit['corpus_id']] for hit in hits[:selected_candidates]]

     def descriptive_search(self, claim, polarity=True, target='premise', considered_candidates=50, selected_candidates=5):
-        self.load_main_model()
+        self.load_essence()

-        if self.main_model_ready == False:
-            return ['The aligned model is unavailable at the required location.']
+        if self.essence_ready == False:
+            return ['The essence is not present at the required location.']

         selected_candidates = min(selected_candidates, considered_candidates)
         considered_candidates = min(considered_candidates, len(self.entry_filenames))
@@ -79,10 +79,10 @@ def descriptive_search(self, claim, polarity=True, target='premise', considered_candidates=50, selected_candidates=5):
         return results

     def open_dialogue(self, question, considered_candidates=3):
-        self.load_main_model()
+        self.load_essence()

-        if self.main_model_ready == False:
-            return ['The aligned model is unavailable at the required location.']
+        if self.essence_ready == False:
+            return ['The essence is not present at the required location.']

         candidate_entry_filenames = self.fluid_search(question, selected_candidates=considered_candidates)
         candidate_entry_contents = reversed([self.entries[e][0] for e in candidate_entry_filenames])
@@ -95,33 +95,33 @@ def open_dialogue(self, question, considered_candidates=3):
             max_length=len(input_ids[0]) + 100,
             top_p=0.9,
             top_k=40,
-            temperature=1
+            temperature=0.9
         )

         output_sample = self.gen_tokenizer.decode(generator_output[0], skip_special_tokens=True)[len(generator_prompt):]
         output_sample = re.sub(r'^[\W_]+|[\W_]+$', '', output_sample)
         output_sample = re.sub(r'[^a-zA-Z0-9\s]{3,}', '', output_sample)
-        output_sample = output_sample.split('Q:')[0].strip()
+        output_sample = output_sample.split('Q:')[0].split('\n\n')[0].strip()
         output_sample += '...'

         return [output_sample]

-    def load_auxiliary_models(self):
-        print('Loading auxiliary models...')
+    def load_skeleton(self):
+        print('Loading skeleton...')
         self.text_encoder = SentenceTransformer('msmarco-distilbert-base-v2')
         self.pair_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-4')
         self.nli = CrossEncoder('cross-encoder/nli-distilroberta-base')
-        self.auxiliary_models_ready = True
+        self.skeleton_ready = True

-    def load_main_model(self):
+    def load_essence(self):
         tentative_folder_path = os.path.join(self.root_dir, '.obsidian/plugins/Dual/essence')
         tentative_file_path = os.path.join(tentative_folder_path, 'pytorch_model.bin')

-        if self.main_model_ready == False and os.path.isfile(tentative_file_path):
-            print('Loading main model...')
+        if self.essence_ready == False and os.path.isfile(tentative_file_path):
+            print('Loading essence...')
             self.gen_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
             self.gen_model = GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path=tentative_folder_path, pad_token_id=self.gen_tokenizer.eos_token_id)
-            self.main_model_ready = True
+            self.essence_ready = True

     def copy_snapshot(self):
         return {
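The substance of this commit is a rename: `load_auxiliary_models`/`load_main_model` become `load_skeleton`/`load_essence`, the matching ready flags and user-facing error strings are updated to the same vocabulary, and `load_essence` keeps its lazy-loading guard (load at most once, and only if the fine-tuned weights exist on disk). Below is a minimal standalone sketch of that guard pattern; the class name, return value, and usage line are illustrative assumptions, while the paths and the `transformers` calls mirror the diff above.

```python
# A minimal sketch of the lazy-loading guard used by load_essence() above.
# EssenceLoader and its return value are assumptions, not part of core.py.
import os

from transformers import GPT2LMHeadModel, GPT2Tokenizer


class EssenceLoader:
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.essence_ready = False

    def load_essence(self):
        folder = os.path.join(self.root_dir, '.obsidian/plugins/Dual/essence')
        weights = os.path.join(folder, 'pytorch_model.bin')

        # Load at most once, and only if the fine-tuned weights are on disk;
        # callers check the flag and return an error message when it stays False.
        if not self.essence_ready and os.path.isfile(weights):
            print('Loading essence...')
            self.gen_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
            self.gen_model = GPT2LMHeadModel.from_pretrained(
                folder, pad_token_id=self.gen_tokenizer.eos_token_id)
            self.essence_ready = True
        return self.essence_ready


if __name__ == '__main__':
    print(EssenceLoader('.').load_essence())  # False unless the weights exist
```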

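The commit also tweaks decoding in `open_dialogue`: `temperature` drops from 1 to 0.9, and the post-processing now truncates the sample at the first blank line in addition to the next 'Q:' marker. A runnable sketch of the resulting decode-and-clean pipeline follows; stock `gpt2` stands in for the fine-tuned essence, the prompt is invented, and `do_sample=True` is assumed (it is not visible in the hunk, but `top_p`/`top_k`/`temperature` only take effect when sampling is enabled).

```python
# A runnable sketch of the decoding and post-processing after this commit,
# under the assumptions stated above.
import re

from transformers import GPT2LMHeadModel, GPT2Tokenizer

gen_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gen_model = GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=gen_tokenizer.eos_token_id)

generator_prompt = 'Q: What does this note argue?\nA:'  # invented prompt
input_ids = gen_tokenizer.encode(generator_prompt, return_tensors='pt')

generator_output = gen_model.generate(
    input_ids,
    do_sample=True,                       # assumed; not shown in the diff
    max_length=len(input_ids[0]) + 100,   # cap the continuation at ~100 tokens
    top_p=0.9,
    top_k=40,
    temperature=0.9,                      # lowered from 1 in this commit
)

output_sample = gen_tokenizer.decode(generator_output[0], skip_special_tokens=True)[len(generator_prompt):]
output_sample = re.sub(r'^[\W_]+|[\W_]+$', '', output_sample)     # trim edge punctuation
output_sample = re.sub(r'[^a-zA-Z0-9\s]{3,}', '', output_sample)  # drop long symbol runs
# New in this commit: stop at the first blank line as well as at the next 'Q:'.
output_sample = output_sample.split('Q:')[0].split('\n\n')[0].strip()
print(output_sample + '...')
```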