From a87d73393d6df10703708db5f8abee4ed34ebff8 Mon Sep 17 00:00:00 2001 From: zema1 Date: Thu, 23 Nov 2017 21:12:37 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A1=A5=E5=85=A8account=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .travis.yml | 2 +- account/tests.py | 192 +++++++++++++++++++++++++++++++++-------- account/urls/oj.py | 2 +- account/views/admin.py | 2 +- account/views/oj.py | 2 +- run_test.py | 2 +- 6 files changed, 162 insertions(+), 40 deletions(-) diff --git a/.travis.yml b/.travis.yml index 25b3ba13..7782d02c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,7 +17,7 @@ install: script: - docker ps -a - flake8 . - - coverage run --source='.' manage.py test + - coverage run --include="$PWD/*" manage.py test - coverage report notifications: slack: onlinejudgeteam:BzBz8UFgmS5crpiblof17K2W diff --git a/account/tests.py b/account/tests.py index ccc8b1a8..4279c96c 100644 --- a/account/tests.py +++ b/account/tests.py @@ -1,6 +1,7 @@ import time from unittest import mock from datetime import timedelta +from copy import deepcopy from django.contrib import auth from django.utils.timezone import now @@ -34,18 +35,32 @@ class PermissionDecoratorTest(APITestCase): class DuplicateUserCheckAPITest(APITestCase): def setUp(self): - self.create_user("test", "test123", login=False) + user = self.create_user("test", "test123", login=False) + user.email = "test@test.com" + user.save() self.url = self.reverse("check_username_or_email") def test_duplicate_username(self): resp = self.client.post(self.url, data={"username": "test"}) data = resp.data["data"] self.assertEqual(data["username"], True) + resp = self.client.post(self.url, data={"username": "Test"}) + self.assertEqual(resp.data["data"]["username"], True) def test_ok_username(self): resp = self.client.post(self.url, data={"username": "test1"}) data = resp.data["data"] - self.assertEqual(data["username"], False) + self.assertFalse(data["username"]) + + def test_duplicate_email(self): + resp = self.client.post(self.url, data={"email": "test@test.com"}) + self.assertEqual(resp.data["data"]["email"], True) + resp = self.client.post(self.url, data={"email": "Test@Test.com"}) + self.assertTrue(resp.data["data"]["email"]) + + def test_ok_email(self): + resp = self.client.post(self.url, data={"email": "aa@test.com"}) + self.assertFalse(resp.data["data"]["email"]) class TFARequiredCheckAPITest(APITestCase): @@ -87,6 +102,12 @@ class UserLoginAPITest(APITestCase): user = auth.get_user(self.client) self.assertTrue(user.is_authenticated()) + def test_login_with_correct_info_upper_username(self): + resp = self.client.post(self.login_url, data={"username": self.username.upper(), "password": self.password}) + self.assertDictEqual(resp.data, {"error": None, "data": "Succeeded"}) + user = auth.get_user(self.client) + self.assertTrue(user.is_authenticated()) + def test_login_with_wrong_info(self): response = self.client.post(self.login_url, data={"username": self.username, "password": "invalid_password"}) @@ -346,19 +367,48 @@ class ResetPasswordAPITest(CaptchaTest): self.assertDictEqual(resp.data, {"error": "error", "data": "Token has expired"}) -class UserChangePasswordAPITest(CaptchaTest): +class UserChangeEmailAPITest(APITestCase): + def setUp(self): + self.url = self.reverse("user_change_email_api") + self.user = self.create_user("test", "test123") + self.new_mail = "test@oj.com" + self.data = {"password": "test123", "new_email": self.new_mail} + + def test_change_email_success(self): + resp = self.client.post(self.url, data=self.data) + self.assertSuccess(resp) + + def test_wrong_password(self): + self.data["password"] = "aaaa" + resp = self.client.post(self.url, data=self.data) + self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"}) + + def test_duplicate_email(self): + u = self.create_user("aa", "bb", login=False) + u.email = self.new_mail + u.save() + resp = self.client.post(self.url, data=self.data) + self.assertDictEqual(resp.data, {"error": "error", "data": "The email is owned by other account"}) + + +class UserChangePasswordAPITest(APITestCase): def setUp(self): - self.client = APIClient() self.url = self.reverse("user_change_password_api") # Create user at first self.username = "test_user" self.old_password = "testuserpassword" self.new_password = "new_password" - self.create_user(username=self.username, password=self.old_password, login=False) + self.user = self.create_user(username=self.username, password=self.old_password, login=False) - self.data = {"old_password": self.old_password, "new_password": self.new_password, - "captcha": self._set_captcha(self.client.session)} + self.data = {"old_password": self.old_password, "new_password": self.new_password} + + def _get_tfa_code(self): + user = User.objects.first() + code = OtpAuth(user.tfa_token).totp() + if len(str(code)) < 6: + code = (6 - len(str(code))) * "0" + str(code) + return code def test_login_required(self): response = self.client.post(self.url, data=self.data) @@ -376,6 +426,58 @@ class UserChangePasswordAPITest(CaptchaTest): response = self.client.post(self.url, data=self.data) self.assertEqual(response.data, {"error": "error", "data": "Invalid old password"}) + def test_tfa_code_required(self): + self.user.two_factor_auth = True + self.user.tfa_token = "tfa_token" + self.user.save() + self.assertTrue(self.client.login(username=self.username, password=self.old_password)) + self.data["tfa_code"] = rand_str(6) + resp = self.client.post(self.url, data=self.data) + self.assertEqual(resp.data, {"error": "error", "data": "Invalid two factor verification code"}) + + self.data["tfa_code"] = self._get_tfa_code() + resp = self.client.post(self.url, data=self.data) + self.assertSuccess(resp) + + +class UserRankAPITest(APITestCase): + def setUp(self): + self.url = self.reverse("user_rank_api") + self.create_user("test1", "test123", login=False) + self.create_user("test2", "test123", login=False) + test1 = User.objects.get(username="test1") + profile1 = test1.userprofile + profile1.submission_number = 10 + profile1.accepted_number = 10 + profile1.total_score = 240 + profile1.save() + + test2 = User.objects.get(username="test2") + profile2 = test2.userprofile + profile2.submission_number = 15 + profile2.accepted_number = 10 + profile2.total_score = 700 + profile2.save() + + def test_get_acm_rank(self): + resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM}) + self.assertSuccess(resp) + data = resp.data["data"]["results"] + self.assertEqual(data[0]["user"]["username"], "test1") + self.assertEqual(data[1]["user"]["username"], "test2") + + def test_get_oi_rank(self): + resp = self.client.get(self.url, data={"rule": ContestRuleType.OI}) + self.assertSuccess(resp) + data = resp.data["data"]["results"] + self.assertEqual(data[0]["user"]["username"], "test2") + self.assertEqual(data[1]["user"]["username"], "test1") + + +class ProfileProblemDisplayIDRefreshAPITest(APITestCase): + def setUp(self): + pass + class AdminUserTest(APITestCase): def setUp(self): @@ -453,36 +555,56 @@ class AdminUserTest(APITestCase): self.assertTrue(resp_data["open_api"]) self.assertEqual(User.objects.get(id=self.regular_user.id).open_api_appkey, key) + def test_import_users(self): + data = {"users": [["user1", "pass1", "eami1@e.com"], + ["user1", "pass1", "eami1@e.com"], + ["user2", "pass2"], ["user3", "pass3", "eamil3@e.com"]] + } + resp = self.client.post(self.url, data) + self.assertSuccess(resp) + self.assertDictEqual(resp.data["data"], {"omitted_count": 1, + "created_count": 2, + "get_count": 1}) + # successfully created 2 users + self.assertEqual(User.objects.all().count(), 4) -class UserRankAPITest(APITestCase): + def test_delete_users(self): + self.test_import_users() + user_ids = User.objects.filter(username__in=["user1", "user3"]).values_list("id", flat=True) + user_ids = ",".join([str(id) for id in user_ids]) + resp = self.client.delete(self.url + "?id=" + user_ids) + self.assertSuccess(resp) + self.assertEqual(User.objects.all().count(), 2) + + +class GenerateUserAPITest(APITestCase): def setUp(self): - self.url = self.reverse("user_rank_api") - self.create_user("test1", "test123", login=False) - self.create_user("test2", "test123", login=False) - test1 = User.objects.get(username="test1") - profile1 = test1.userprofile - profile1.submission_number = 10 - profile1.accepted_number = 10 - profile1.total_score = 240 - profile1.save() + self.create_super_admin() + self.url = self.reverse("generate_user_api") + self.data = { + "number_from": 100, "number_to": 105, + "prefix": "pre", "suffix": "suf", + "default_email": "test@test.com", + "password_length": 8 + } - test2 = User.objects.get(username="test2") - profile2 = test2.userprofile - profile2.submission_number = 15 - profile2.accepted_number = 10 - profile2.total_score = 700 - profile2.save() + def test_error_case(self): + data = deepcopy(self.data) + data["prefix"] = "t" * 16 + data["suffix"] = "s" * 14 + resp = self.client.post(self.url, data=data) + self.assertEqual(resp.data["data"], "Username should not more than 32 characters") - def test_get_acm_rank(self): - resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM}) + data2 = deepcopy(self.data) + data2["number_from"] = 106 + resp = self.client.post(self.url, data=data2) + self.assertEqual(resp.data["data"], "Start number must be lower than end number") + + @mock.patch("account.views.admin.xlsxwriter.Workbook") + def test_generate_user_success(self, mock_workbook): + resp = self.client.post(self.url, data=self.data) self.assertSuccess(resp) - data = resp.data["data"]["results"] - self.assertEqual(data[0]["user"]["username"], "test1") - self.assertEqual(data[1]["user"]["username"], "test2") - - def test_get_oi_rank(self): - resp = self.client.get(self.url, data={"rule": ContestRuleType.OI}) - self.assertSuccess(resp) - data = resp.data["data"]["results"] - self.assertEqual(data[0]["user"]["username"], "test2") - self.assertEqual(data[1]["user"]["username"], "test1") + mock_workbook.assert_called() + data = resp.data["data"] + self.assertEqual(data["created_count"], 6) + self.assertEqual(data["get_count"], 0) diff --git a/account/urls/oj.py b/account/urls/oj.py index c1915f49..a92dd5be 100644 --- a/account/urls/oj.py +++ b/account/urls/oj.py @@ -14,7 +14,7 @@ urlpatterns = [ url(r"^logout/?$", UserLogoutAPI.as_view(), name="user_logout_api"), url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"), url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"), - url(r"^change_email/?$", UserChangeEmailAPI.as_view(), name="user_change_email"), + url(r"^change_email/?$", UserChangeEmailAPI.as_view(), name="user_change_email_api"), url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"), url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"), url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"), diff --git a/account/views/admin.py b/account/views/admin.py index b25ccb44..ad731b5a 100644 --- a/account/views/admin.py +++ b/account/views/admin.py @@ -129,7 +129,7 @@ class UserAdminAPI(APIView): def delete(self, request): id = request.GET.get("id") if not id: - return self.error("Invalid Parameter, user_id is required") + return self.error("Invalid Parameter, id is required") for user_id in id.split(","): if user_id: error = self.delete_one(user_id) diff --git a/account/views/oj.py b/account/views/oj.py index 4af6b824..344a4196 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -197,7 +197,7 @@ class UsernameOrEmailCheck(APIView): "email": False } if data.get("username"): - result["username"] = User.objects.filter(username=data["username"]).exists() + result["username"] = User.objects.filter(username=data["username"].lower()).exists() if data.get("email"): result["email"] = User.objects.filter(email=data["email"].lower()).exists() return self.success(result) diff --git a/run_test.py b/run_test.py index 2613ca75..cb4c6300 100644 --- a/run_test.py +++ b/run_test.py @@ -21,7 +21,7 @@ print("running flake8...") if os.system("flake8 --statistics ."): exit() -ret = os.system("coverage run --source='.' ./manage.py test {module} --settings={setting}".format(module=test_module, setting=setting)) +ret = os.system("coverage run --include=\"$PWD/*\" manage.py test {module} --settings={setting}".format(module=test_module, setting=setting)) if not ret and is_coverage: os.system("coverage html && open htmlcov/index.html")