Skip to content

Commit

Permalink
Fix permission
Browse files Browse the repository at this point in the history
  • Loading branch information
m3hrdadfi committed Aug 12, 2021
1 parent 21b8f34 commit 227ce59
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions apps/st_app/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import streamlit as st

import os
import torch
from transformers import pipeline
from transformers import AutoConfig, AutoTokenizer, AutoModelForTokenClassification
Expand All @@ -16,6 +16,7 @@
"Persian (fa)": "m3hrdadfi/typo-detector-distilbert-fa",
"Icelandic (is)": "m3hrdadfi/typo-detector-distilbert-is",
}
API_TOKEN = os.environ.get("API_TOKEN")


class TypoDetector:
Expand All @@ -34,11 +35,15 @@ def __init__(
self.nlp = None
self.normalizer = None

def load(self):
def load(self, api_token=None):
api_token = api_token if api_token else False
if not self.debug:
self.config = AutoConfig.from_pretrained(self.model_name_or_path)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
self.model = AutoModelForTokenClassification.from_pretrained(self.model_name_or_path, config=self.config)
self.config = AutoConfig.from_pretrained(self.model_name_or_path, use_auth_token=api_token)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_auth_token=api_token)
self.model = AutoModelForTokenClassification.from_pretrained(
self.model_name_or_path,
config=self.config,
use_auth_token=api_token)
self.nlp = pipeline(
self.task_name,
model=self.model,
Expand Down Expand Up @@ -70,7 +75,7 @@ def load_typo_detectors():
is_detector.load()

fa_detector = TypoDetector(MODELS["Persian (fa)"])
fa_detector.load()
fa_detector.load(api_token=API_TOKEN)

return {
"en": en_detector,
Expand Down

0 comments on commit 227ce59

Please sign in to comment.