diff --git a/leetcode/auth.py b/leetcode/auth.py index df13f57..18a40af 100644 --- a/leetcode/auth.py +++ b/leetcode/auth.py @@ -71,6 +71,11 @@ def is_login(): return 'user_name' in data and data['user_name'] != '' +def ensure_login(): + if not is_login(): + login() + + def retrieve(url, headers=None, method='GET', data=None): try: if method == 'GET': diff --git a/leetcode/leetcode.py b/leetcode/leetcode.py index 7869810..bbb3f4d 100644 --- a/leetcode/leetcode.py +++ b/leetcode/leetcode.py @@ -4,7 +4,7 @@ import logging from bs4 import BeautifulSoup from .config import config -from .auth import session, headers, retrieve +from .auth import session, headers, retrieve, ensure_login from .code import * BASE_URL = 'https://leetcode.com' @@ -30,6 +30,12 @@ def merge_two_dicts(x, y): 'Go': 'go', } + +def safe_retrieve(url, headers=None, method='GET', data=None): + ensure_login() + return retrieve(url, headers, method, data) + + class Quiz(object): def __init__(self): self.id = None @@ -48,7 +54,7 @@ def __init__(self): self.logger = logging.getLogger(__name__) def load(self): - r = retrieve(self.url) + r = safe_retrieve(self.url) if r.status_code != 200: return False text = r.text.encode('utf-8') @@ -56,8 +62,8 @@ def load(self): bs = BeautifulSoup(text, 'lxml') if bs.find('form', 'form-signin'): - self.session.cookies.clear() - r = retrieve(BASE_URL + item.url) + session.cookies.clear() + r = safe_retrieve(BASE_URL + item.url) try: content = bs.find('div', 'question-description') @@ -90,6 +96,9 @@ def load(self): return False def submit(self, code): + # Call this upfront so we have the right CSRF + ensure_login() + body = { 'question_id': self.id, 'test_mode': False, 'lang': LANG_MAPPING.get(config.language, 'cpp'), @@ -109,7 +118,7 @@ def submit(self, code): 'X-CSRFToken': csrftoken, 'X-Requested-With': 'XMLHttpRequest'}) - r = retrieve(self.url + '/submit/', method='POST', data=json.dumps(body), headers=newheaders) + r = safe_retrieve(self.url + '/submit/', method='POST', data=json.dumps(body), headers=newheaders) if r.status_code != 200: return (False, 'Request failed!') text = r.text.encode('utf-8') @@ -124,7 +133,7 @@ def submit(self, code): def check_submission_result(self, submission_id): url = SUBMISSION_URL.format(id=submission_id) - r = retrieve(url) + r = safe_retrieve(url) if r.status_code != 200: return (-100, 'Request failed!') text = r.text.encode('utf-8') @@ -161,7 +170,7 @@ def solved(self): return [i for i in self.quizzes if i.submission_status == 'ac'] def load(self): - r = retrieve(API_URL) + r = safe_retrieve(API_URL) if r.status_code != 200: return None text = r.text.encode('utf-8') diff --git a/tests/test_leetcode.py b/tests/test_leetcode.py index 15e7961..4524326 100644 --- a/tests/test_leetcode.py +++ b/tests/test_leetcode.py @@ -1,7 +1,7 @@ import unittest import mock from leetcode.leetcode import * -from leetcode.auth import NetworkError +from leetcode.auth import NetworkError, LOGIN_URL, API_URL import requests_mock import requests @@ -66,6 +66,7 @@ def test_retrieve_home(self): with requests_mock.Mocker() as m: m.get(API_URL, status_code=403) + m.get(LOGIN_URL, status_code=403) self.assertIsNone(self.leet.load()) m.get(API_URL, json={"error": "not found"}) self.assertIsNone(self.leet.load()) @@ -119,6 +120,7 @@ def test_retrieve_detail(self): item.submission_status = 'ac' with requests_mock.Mocker() as m: + m.get(API_URL, json={'user_name': 'fake'}) m.get('http://hello.com', status_code=403) self.assertFalse(item.load()) m.get('http://hello.com', text=data) @@ -142,6 +144,7 @@ def test_submit_code(self): item.submission_status = 'ac' with requests_mock.Mocker() as m: + m.get(API_URL, json={'user_name': 'fake'}) m.post(item.url + '/submit/', status_code=402) self.assertFalse(item.submit('code')[0]) @@ -162,6 +165,7 @@ def test_submission_result(self): item.submission_status = 'ac' url = SUBMISSION_URL.format(id=1) with requests_mock.Mocker() as m: + m.get(API_URL, json={'user_name': 'fake'}) m.get(url, status_code=403) self.assertEqual(item.check_submission_result(1)[0], -100) m.get(url, json={ 'state': 'PENDING' })