From 7cc33d07011b4f8400b418f4ffbb295ec342db47 Mon Sep 17 00:00:00 2001 From: virusdefender Date: Fri, 24 Nov 2017 22:26:56 +0800 Subject: [PATCH] use bulk_create and transcation for importing user --- account/tests.py | 17 +++++++++++------ account/views/admin.py | 38 +++++++++++++++++++++----------------- 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/account/tests.py b/account/tests.py index 4279c96c..bf537273 100644 --- a/account/tests.py +++ b/account/tests.py @@ -557,20 +557,25 @@ class AdminUserTest(APITestCase): def test_import_users(self): data = {"users": [["user1", "pass1", "eami1@e.com"], - ["user1", "pass1", "eami1@e.com"], - ["user2", "pass2"], ["user3", "pass3", "eamil3@e.com"]] + ["user2", "pass3", "eamil3@e.com"]] } resp = self.client.post(self.url, data) self.assertSuccess(resp) - self.assertDictEqual(resp.data["data"], {"omitted_count": 1, - "created_count": 2, - "get_count": 1}) # successfully created 2 users self.assertEqual(User.objects.all().count(), 4) + def test_import_duplicate_user(self): + data = {"users": [["user1", "pass1", "eami1@e.com"], + ["user1", "pass1", "eami1@e.com"]] + } + resp = self.client.post(self.url, data) + self.assertFailed(resp, "DETAIL: Key (username)=(user1) already exists.") + # no user is created + self.assertEqual(User.objects.all().count(), 2) + def test_delete_users(self): self.test_import_users() - user_ids = User.objects.filter(username__in=["user1", "user3"]).values_list("id", flat=True) + user_ids = User.objects.filter(username__in=["user1", "user2"]).values_list("id", flat=True) user_ids = ",".join([str(id) for id in user_ids]) resp = self.client.delete(self.url + "?id=" + user_ids) self.assertSuccess(resp) diff --git a/account/views/admin.py b/account/views/admin.py index 062fdcf1..f0b93c65 100644 --- a/account/views/admin.py +++ b/account/views/admin.py @@ -1,8 +1,11 @@ import os import re import xlsxwriter + +from django.db import transaction, IntegrityError from django.db.models import Q from django.http import HttpResponse +from django.contrib.auth.hashers import make_password from submission.models import Submission from utils.api import APIView, validate_serializer @@ -18,26 +21,27 @@ class UserAdminAPI(APIView): @validate_serializer(ImportUserSeralizer) @super_admin_required def post(self, request): + """ + Generate user + """ data = request.data["users"] - omitted_count = created_count = get_count = 0 + + user_list = [] for user_data in data: if len(user_data) != 3 or len(user_data[0]) > 32: - omitted_count += 1 - continue - user, created = User.objects.get_or_create(username=user_data[0]) - user.set_password(user_data[1]) - user.email = user_data[2] - user.save() - if created: - UserProfile.objects.create(user=user) - created_count += 1 - else: - get_count += 1 - return self.success({ - "omitted_count": omitted_count, - "created_count": created_count, - "get_count": get_count - }) + return self.error(f"Error occurred while processing data '{user_data}'") + user_list.append(User(username=user_data[0], password=make_password(user_data[1]), email=user_data[2])) + + try: + with transaction.atomic(): + ret = User.objects.bulk_create(user_list) + UserProfile.objects.bulk_create([UserProfile(user=user) for user in ret]) + return self.success() + except IntegrityError as e: + # Extract detail from exception message + # duplicate key value violates unique constraint "user_username_key" + # DETAIL: Key (username)=(root11) already exists. + return self.error(str(e).split("\n")[1]) @validate_serializer(EditUserSerializer) @super_admin_required