mirror of
https://github.com/QingdaoU/OnlineJudge.git
synced 2024-09-21 08:23:20 +00:00
使用Python3和更科学的API写法
This commit is contained in:
parent
d9b1141cb9
commit
172fd4b1f4
@ -1,13 +1,10 @@
|
|||||||
# coding=utf-8
|
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
import urllib
|
|
||||||
import json
|
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
from django.http import HttpResponse
|
from django.http import HttpResponse
|
||||||
from django.utils.translation import ugettext as _
|
from django.utils.translation import ugettext as _
|
||||||
|
|
||||||
from utils.shortcuts import JSONResponse
|
from utils.api import JSONResponse
|
||||||
from .models import AdminType
|
from .models import AdminType
|
||||||
|
|
||||||
|
|
||||||
@ -19,7 +16,7 @@ class BasePermissionDecorator(object):
|
|||||||
return functools.partial(self.__call__, obj)
|
return functools.partial(self.__call__, obj)
|
||||||
|
|
||||||
def error(self, data):
|
def error(self, data):
|
||||||
return JSONResponse({"error": "permission-denied", "data": data})
|
return JSONResponse.response({"error": "permission-denied", "data": data})
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
self.request = args[1]
|
self.request = args[1]
|
||||||
|
@ -1,12 +1,8 @@
|
|||||||
# coding=utf-8
|
|
||||||
import time
|
import time
|
||||||
import json
|
|
||||||
|
|
||||||
from django.http import HttpResponse
|
|
||||||
from django.utils.translation import ugettext as _
|
from django.utils.translation import ugettext as _
|
||||||
from django.contrib import auth
|
from django.contrib import auth
|
||||||
|
|
||||||
from utils.shortcuts import JSONResponse
|
from utils.api import JSONResponse
|
||||||
from .models import AdminType
|
from .models import AdminType
|
||||||
|
|
||||||
|
|
||||||
@ -17,7 +13,7 @@ class SessionSecurityMiddleware(object):
|
|||||||
# 24 hours passed since last visit
|
# 24 hours passed since last visit
|
||||||
if time.time() - request.session["last_activity"] >= 24 * 60 * 60:
|
if time.time() - request.session["last_activity"] >= 24 * 60 * 60:
|
||||||
auth.logout(request)
|
auth.logout(request)
|
||||||
return JSONResponse({"error": "login-required", "data": _("Please login in first")})
|
return JSONResponse.response({"error": "login-required", "data": _("Please login in first")})
|
||||||
# update last active time
|
# update last active time
|
||||||
request.session["last_activity"] = time.time()
|
request.session["last_activity"] = time.time()
|
||||||
|
|
||||||
@ -27,4 +23,4 @@ class AdminRequiredMiddleware(object):
|
|||||||
path = request.path_info
|
path = request.path_info
|
||||||
if path.startswith("/admin/") or path.startswith("/api/admin/"):
|
if path.startswith("/admin/") or path.startswith("/api/admin/"):
|
||||||
if not(request.user.is_authenticated() and request.user.is_admin()):
|
if not(request.user.is_authenticated() and request.user.is_admin()):
|
||||||
return JSONResponse({"error": "login-required", "data": _("Please login in first")})
|
return JSONResponse.response({"error": "login-required", "data": _("Please login in first")})
|
@ -1,5 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
from django.contrib.auth.models import AbstractBaseUser
|
from django.contrib.auth.models import AbstractBaseUser
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from jsonfield import JSONField
|
from jsonfield import JSONField
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
from rest_framework import serializers
|
from utils.api import serializers, DateTimeTZField
|
||||||
|
|
||||||
from utils.serializers import DateTimeTZField
|
|
||||||
from .models import User, AdminType
|
from .models import User, AdminType
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,17 +1,13 @@
|
|||||||
# coding=utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import mock
|
|
||||||
from django.contrib import auth
|
from django.contrib import auth
|
||||||
from django.core.urlresolvers import reverse
|
|
||||||
from django.utils.translation import ugettext as _
|
from django.utils.translation import ugettext as _
|
||||||
from rest_framework.test import APIClient
|
|
||||||
|
|
||||||
from utils.otp_auth import OtpAuth
|
from utils.otp_auth import OtpAuth
|
||||||
from utils.shortcuts import rand_str
|
from utils.shortcuts import rand_str
|
||||||
from utils.tests import APITestCase
|
from utils.api.tests import APITestCase, APIClient
|
||||||
|
|
||||||
from .models import User, AdminType
|
from .models import User, AdminType
|
||||||
|
|
||||||
|
|
||||||
@ -37,7 +33,7 @@ class UserLoginAPITest(APITestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.username = self.password = "test"
|
self.username = self.password = "test"
|
||||||
self.user = self.create_user(username=self.username, password=self.password)
|
self.user = self.create_user(username=self.username, password=self.password)
|
||||||
self.login_url = reverse("user_login_api")
|
self.login_url = self.reverse("user_login_api")
|
||||||
|
|
||||||
def _set_tfa(self):
|
def _set_tfa(self):
|
||||||
self.user.two_factor_auth = True
|
self.user.two_factor_auth = True
|
||||||
@ -110,7 +106,7 @@ class CaptchaTest(APITestCase):
|
|||||||
class UserRegisterAPITest(CaptchaTest):
|
class UserRegisterAPITest(CaptchaTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.client = APIClient()
|
self.client = APIClient()
|
||||||
self.register_url = reverse("user_register_api")
|
self.register_url = self.reverse("user_register_api")
|
||||||
self.captcha = rand_str(4)
|
self.captcha = rand_str(4)
|
||||||
|
|
||||||
self.data = {"username": "test_user", "password": "testuserpassword",
|
self.data = {"username": "test_user", "password": "testuserpassword",
|
||||||
@ -150,7 +146,7 @@ class UserRegisterAPITest(CaptchaTest):
|
|||||||
class UserChangePasswordAPITest(CaptchaTest):
|
class UserChangePasswordAPITest(CaptchaTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.client = APIClient()
|
self.client = APIClient()
|
||||||
self.url = reverse("user_change_password_api")
|
self.url = self.reverse("user_change_password_api")
|
||||||
|
|
||||||
# Create user at first
|
# Create user at first
|
||||||
self.username = "test_user"
|
self.username = "test_user"
|
||||||
@ -183,7 +179,7 @@ class AdminUserTest(APITestCase):
|
|||||||
self.user = self.create_super_admin(login=True)
|
self.user = self.create_super_admin(login=True)
|
||||||
self.username = self.password = "test"
|
self.username = self.password = "test"
|
||||||
self.regular_user = self.create_user(username=self.username, password=self.password)
|
self.regular_user = self.create_user(username=self.username, password=self.password)
|
||||||
self.url = reverse("user_admin_api")
|
self.url = self.reverse("user_admin_api")
|
||||||
self.data = {"id": self.regular_user.id, "username": self.username, "real_name": "test_name",
|
self.data = {"id": self.regular_user.id, "username": self.username, "real_name": "test_name",
|
||||||
"email": "test@qq.com", "admin_type": AdminType.REGULAR_USER,
|
"email": "test@qq.com", "admin_type": AdminType.REGULAR_USER,
|
||||||
"open_api": True, "two_factor_auth": False, "is_disabled": False}
|
"open_api": True, "two_factor_auth": False, "is_disabled": False}
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
from django.conf.urls import url
|
from django.conf.urls import url
|
||||||
|
|
||||||
from ..views.admin import UserAdminAPIView
|
from ..views.admin import UserAdminAPIView
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
from django.conf.urls import url
|
from django.conf.urls import url
|
||||||
|
|
||||||
from ..views.oj import UserLoginAPIView, UserRegisterAPIView, UserChangePasswordAPIView
|
from ..views.oj import UserLoginAPIView, UserRegisterAPIView, UserChangePasswordAPIView
|
||||||
|
@ -1,25 +1,25 @@
|
|||||||
# coding=utf-8
|
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from django.core.exceptions import MultipleObjectsReturned
|
from django.core.exceptions import MultipleObjectsReturned
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
from django.utils.translation import ugettext as _
|
from django.utils.translation import ugettext as _
|
||||||
|
|
||||||
from utils.shortcuts import (APIView, paginate_data, rand_str)
|
from utils.api import APIView, validate_serializer
|
||||||
|
from utils.shortcuts import rand_str
|
||||||
|
|
||||||
from ..decorators import super_admin_required
|
from ..decorators import super_admin_required
|
||||||
from ..models import User, AdminType
|
from ..models import User
|
||||||
from ..serializers import (UserSerializer, EditUserSerializer)
|
from ..serializers import (UserSerializer, EditUserSerializer)
|
||||||
|
|
||||||
|
|
||||||
class UserAdminAPIView(APIView):
|
class UserAdminAPIView(APIView):
|
||||||
|
@validate_serializer(EditUserSerializer)
|
||||||
@super_admin_required
|
@super_admin_required
|
||||||
def put(self, request):
|
def put(self, request):
|
||||||
"""
|
"""
|
||||||
Edit user api
|
Edit user api
|
||||||
"""
|
"""
|
||||||
serializer = EditUserSerializer(data=request.data)
|
data = request.data
|
||||||
if serializer.is_valid():
|
|
||||||
data = serializer.data
|
|
||||||
try:
|
try:
|
||||||
user = User.objects.get(id=data["id"])
|
user = User.objects.get(id=data["id"])
|
||||||
except User.DoesNotExist:
|
except User.DoesNotExist:
|
||||||
@ -68,8 +68,6 @@ class UserAdminAPIView(APIView):
|
|||||||
|
|
||||||
user.save()
|
user.save()
|
||||||
return self.success(UserSerializer(user).data)
|
return self.success(UserSerializer(user).data)
|
||||||
else:
|
|
||||||
return self.invalid_serializer(serializer)
|
|
||||||
|
|
||||||
@super_admin_required
|
@super_admin_required
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
@ -97,4 +95,4 @@ class UserAdminAPIView(APIView):
|
|||||||
user = user.filter(Q(username__contains=keyword) |
|
user = user.filter(Q(username__contains=keyword) |
|
||||||
Q(real_name__contains=keyword) |
|
Q(real_name__contains=keyword) |
|
||||||
Q(email__contains=keyword))
|
Q(email__contains=keyword))
|
||||||
return self.success(paginate_data(request, user, UserSerializer))
|
return self.success(self.paginate_data(request, user, UserSerializer))
|
||||||
|
@ -1,13 +1,10 @@
|
|||||||
# coding=utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
from django.contrib import auth
|
from django.contrib import auth
|
||||||
from django.core.exceptions import MultipleObjectsReturned
|
from django.core.exceptions import MultipleObjectsReturned
|
||||||
from django.utils.translation import ugettext as _
|
from django.utils.translation import ugettext as _
|
||||||
|
|
||||||
|
from utils.api import APIView, validate_serializer
|
||||||
from utils.captcha import Captcha
|
from utils.captcha import Captcha
|
||||||
from utils.otp_auth import OtpAuth
|
from utils.otp_auth import OtpAuth
|
||||||
from utils.shortcuts import (APIView, )
|
|
||||||
from ..decorators import login_required
|
from ..decorators import login_required
|
||||||
from ..models import User, UserProfile
|
from ..models import User, UserProfile
|
||||||
from ..serializers import (UserLoginSerializer, UserRegisterSerializer,
|
from ..serializers import (UserLoginSerializer, UserRegisterSerializer,
|
||||||
@ -15,13 +12,12 @@ from ..serializers import (UserLoginSerializer, UserRegisterSerializer,
|
|||||||
|
|
||||||
|
|
||||||
class UserLoginAPIView(APIView):
|
class UserLoginAPIView(APIView):
|
||||||
|
@validate_serializer(UserLoginSerializer)
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
"""
|
"""
|
||||||
User login api
|
User login api
|
||||||
"""
|
"""
|
||||||
serializer = UserLoginSerializer(data=request.data)
|
data = request.data
|
||||||
if serializer.is_valid():
|
|
||||||
data = serializer.data
|
|
||||||
user = auth.authenticate(username=data["username"], password=data["password"])
|
user = auth.authenticate(username=data["username"], password=data["password"])
|
||||||
# None is returned if username or password is wrong
|
# None is returned if username or password is wrong
|
||||||
if user:
|
if user:
|
||||||
@ -40,8 +36,6 @@ class UserLoginAPIView(APIView):
|
|||||||
return self.error(_("Invalid two factor verification code"))
|
return self.error(_("Invalid two factor verification code"))
|
||||||
else:
|
else:
|
||||||
return self.error(_("Invalid username or password"))
|
return self.error(_("Invalid username or password"))
|
||||||
else:
|
|
||||||
return self.invalid_serializer(serializer)
|
|
||||||
|
|
||||||
# todo remove this, only for debug use
|
# todo remove this, only for debug use
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
@ -50,13 +44,12 @@ class UserLoginAPIView(APIView):
|
|||||||
|
|
||||||
|
|
||||||
class UserRegisterAPIView(APIView):
|
class UserRegisterAPIView(APIView):
|
||||||
|
@validate_serializer(UserRegisterSerializer)
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
"""
|
"""
|
||||||
User register api
|
User register api
|
||||||
"""
|
"""
|
||||||
serializer = UserRegisterSerializer(data=request.data)
|
data = request.data
|
||||||
if serializer.is_valid():
|
|
||||||
data = serializer.data
|
|
||||||
captcha = Captcha(request)
|
captcha = Captcha(request)
|
||||||
if not captcha.check(data["captcha"]):
|
if not captcha.check(data["captcha"]):
|
||||||
return self.error(_("Invalid captcha"))
|
return self.error(_("Invalid captcha"))
|
||||||
@ -77,19 +70,16 @@ class UserRegisterAPIView(APIView):
|
|||||||
user.save()
|
user.save()
|
||||||
UserProfile.objects.create(user=user)
|
UserProfile.objects.create(user=user)
|
||||||
return self.success(_("Succeeded"))
|
return self.success(_("Succeeded"))
|
||||||
else:
|
|
||||||
return self.invalid_serializer(serializer)
|
|
||||||
|
|
||||||
|
|
||||||
class UserChangePasswordAPIView(APIView):
|
class UserChangePasswordAPIView(APIView):
|
||||||
|
@validate_serializer(UserChangePasswordSerializer)
|
||||||
@login_required
|
@login_required
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
"""
|
"""
|
||||||
User change password api
|
User change password api
|
||||||
"""
|
"""
|
||||||
serializer = UserChangePasswordSerializer(data=request.data)
|
data = request.data
|
||||||
if serializer.is_valid():
|
|
||||||
data = serializer.data
|
|
||||||
captcha = Captcha(request)
|
captcha = Captcha(request)
|
||||||
if not captcha.check(data["captcha"]):
|
if not captcha.check(data["captcha"]):
|
||||||
return self.error(_("Invalid captcha"))
|
return self.error(_("Invalid captcha"))
|
||||||
@ -101,5 +91,3 @@ class UserChangePasswordAPIView(APIView):
|
|||||||
return self.success(_("Succeeded"))
|
return self.success(_("Succeeded"))
|
||||||
else:
|
else:
|
||||||
return self.error(_("Invalid old password"))
|
return self.error(_("Invalid old password"))
|
||||||
else:
|
|
||||||
return self.invalid_serializer(serializer)
|
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
from account.models import User
|
from account.models import User
|
||||||
|
@ -1,10 +1,7 @@
|
|||||||
# coding=utf-8
|
from utils.api import serializers
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
from rest_framework import serializers
|
|
||||||
|
|
||||||
from account.models import User
|
from account.models import User
|
||||||
from utils.serializers import DateTimeTZField
|
from utils.api._serializers import DateTimeTZField
|
||||||
from .models import Announcement
|
from .models import Announcement
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
# coding=utf-8
|
from utils.api.tests import APITestCase, APIClient
|
||||||
from django.core.urlresolvers import reverse
|
|
||||||
from utils.tests import APITestCase
|
|
||||||
|
|
||||||
|
|
||||||
class AnnouncementAdminTest(APITestCase):
|
class AnnouncementAdminTest(APITestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.user = self.create_super_admin(login=True)
|
self.user = self.create_super_admin(login=True)
|
||||||
self.url = reverse("announcement_admin_api")
|
self.url = self.reverse("announcement_admin_api")
|
||||||
|
|
||||||
def test_announcement_list(self):
|
def test_announcement_list(self):
|
||||||
response = self.client.get(self.url)
|
response = self.client.get(self.url)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
from django.conf.urls import url
|
from django.conf.urls import url
|
||||||
|
|
||||||
from ..views import AnnouncementAdminAPIView
|
from ..views import AnnouncementAdminAPIView
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
# coding=utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
from django.utils.translation import ugettext as _
|
from django.utils.translation import ugettext as _
|
||||||
|
|
||||||
from account.decorators import super_admin_required
|
from account.decorators import super_admin_required
|
||||||
from utils.shortcuts import paginate_data, APIView
|
from utils.api import APIView
|
||||||
|
|
||||||
from .models import Announcement
|
from .models import Announcement
|
||||||
from .serializers import (CreateAnnouncementSerializer, AnnouncementSerializer,
|
from .serializers import (CreateAnnouncementSerializer, AnnouncementSerializer,
|
||||||
EditAnnouncementSerializer)
|
EditAnnouncementSerializer)
|
||||||
@ -63,4 +61,4 @@ class AnnouncementAdminAPIView(APIView):
|
|||||||
announcement = Announcement.objects.all().order_by("-create_time")
|
announcement = Announcement.objects.all().order_by("-create_time")
|
||||||
if request.GET.get("visible") == "true":
|
if request.GET.get("visible") == "true":
|
||||||
announcement = announcement.filter(visible=True)
|
announcement = announcement.filter(visible=True)
|
||||||
return self.success(paginate_data(request, announcement, AnnouncementSerializer))
|
return self.success(self.paginate_data(request, announcement, AnnouncementSerializer))
|
||||||
|
@ -19,9 +19,9 @@ class WebsiteConfig(models.Model):
|
|||||||
base_url = models.CharField(max_length=128, default=None)
|
base_url = models.CharField(max_length=128, default=None)
|
||||||
name = models.CharField(max_length=32, default="Online Judge")
|
name = models.CharField(max_length=32, default="Online Judge")
|
||||||
name_shortcut = models.CharField(max_length=32, default="oj")
|
name_shortcut = models.CharField(max_length=32, default="oj")
|
||||||
website_footer = models.CharField(max_length=256, default="Online Judge")
|
website_footer = models.TextField(default="Online Judge")
|
||||||
# allow register
|
# allow register
|
||||||
register = models.BooleanField(default=True)
|
allow_register = models.BooleanField(default=True)
|
||||||
# submission list show all user's submission
|
# submission list show all user's submission
|
||||||
submission_list_show_all = models.BooleanField(default=False)
|
submission_list_show_all = models.BooleanField(default=False)
|
||||||
|
|
||||||
|
2
utils/api/__init__.py
Normal file
2
utils/api/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from .api import *
|
||||||
|
from ._serializers import *
|
17
utils/api/_serializers.py
Normal file
17
utils/api/_serializers.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
|
||||||
|
class JSONField(serializers.Field):
|
||||||
|
def to_representation(self, value):
|
||||||
|
return json.loads(value)
|
||||||
|
|
||||||
|
|
||||||
|
class DateTimeTZField(serializers.DateTimeField):
|
||||||
|
def to_representation(self, value):
|
||||||
|
self.format = "%Y-%-m-%d %-H:%-M:%-S"
|
||||||
|
value = timezone.localtime(value)
|
||||||
|
return super(DateTimeTZField, self).to_representation(value)
|
179
utils/api/api.py
Normal file
179
utils/api/api.py
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from django.http import HttpResponse, QueryDict
|
||||||
|
from django.utils.decorators import method_decorator
|
||||||
|
from django.views.decorators.csrf import csrf_exempt
|
||||||
|
from django.views.generic import View
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContentType(object):
|
||||||
|
json_request = "application/json"
|
||||||
|
json_response = "application/json;charset=UTF-8"
|
||||||
|
url_encoded_request = "application/x-www-form-urlencoded"
|
||||||
|
binary_response = "application/octet-stream"
|
||||||
|
|
||||||
|
|
||||||
|
class JSONParser(object):
|
||||||
|
content_type = ContentType.json_request
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse(body):
|
||||||
|
return json.loads(body.decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
class URLEncodedParser(object):
|
||||||
|
content_type = ContentType.url_encoded_request
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse(body):
|
||||||
|
return QueryDict(body).dict()
|
||||||
|
|
||||||
|
|
||||||
|
class JSONResponse(object):
|
||||||
|
content_type = ContentType.json_response
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response(cls, data):
|
||||||
|
resp = HttpResponse(json.dumps(data, indent=4), content_type=cls.content_type)
|
||||||
|
resp.data = data
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
class APIView(View):
|
||||||
|
"""
|
||||||
|
Django view的父类, 和django-rest-framework的用法基本一致
|
||||||
|
- request.data获取解析之后的json或者urlencoded数据, dict类型
|
||||||
|
- self.success, self.error和self.invalid_serializer可以根据业需求修改,
|
||||||
|
写到父类中是为了不同的人开发写法统一,不再使用自己的success/error格式
|
||||||
|
- self.response 返回一个django HttpResponse, 具体在self.response_class中实现
|
||||||
|
- parse请求的类需要定义在request_parser中, 目前只支持json和urlencoded的类型, 用来解析请求的数据
|
||||||
|
"""
|
||||||
|
request_parsers = (JSONParser, URLEncodedParser)
|
||||||
|
response_class = JSONResponse
|
||||||
|
|
||||||
|
def _get_request_data(self, request):
|
||||||
|
if request.method != "GET":
|
||||||
|
body = request.body
|
||||||
|
content_type = request.META.get("CONTENT_TYPE")
|
||||||
|
if not content_type:
|
||||||
|
raise ValueError("content_type is required")
|
||||||
|
for parser in self.request_parsers:
|
||||||
|
if content_type.startswith(parser.content_type):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown content_type '%s'" % content_type)
|
||||||
|
if body:
|
||||||
|
return parser.parse(body)
|
||||||
|
return {}
|
||||||
|
return request.GET
|
||||||
|
|
||||||
|
def response(self, data):
|
||||||
|
return self.response_class.response(data)
|
||||||
|
|
||||||
|
def success(self, data=None):
|
||||||
|
return self.response({"error": None, "data": data})
|
||||||
|
|
||||||
|
def error(self, msg, err="error"):
|
||||||
|
return self.response({"error": err, "data": msg})
|
||||||
|
|
||||||
|
def invalid_serializer(self, serializer):
|
||||||
|
for k, v in serializer.errors.items():
|
||||||
|
if k != "non_field_errors":
|
||||||
|
return self.error(err="invalid-" + k, msg=k + ": " + v[0])
|
||||||
|
else:
|
||||||
|
return self.error(err="invalid-field", msg=k[0])
|
||||||
|
|
||||||
|
def server_error(self):
|
||||||
|
return self.error(err="server-error", msg="server error")
|
||||||
|
|
||||||
|
def paginate_data(self, request, query_set, object_serializer=None):
|
||||||
|
"""
|
||||||
|
:param request: django的request
|
||||||
|
:param query_set: django model的query set或者其他list like objects
|
||||||
|
:param object_serializer: 用来序列化query set, 如果为None, 则直接对query set切片
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
need_paginate = request.GET.get("limit", None)
|
||||||
|
if need_paginate is None:
|
||||||
|
if object_serializer:
|
||||||
|
return object_serializer(query_set, many=True).data
|
||||||
|
else:
|
||||||
|
return query_set
|
||||||
|
try:
|
||||||
|
limit = int(request.GET.get("limit", "100"))
|
||||||
|
except ValueError:
|
||||||
|
limit = 100
|
||||||
|
if limit < 0:
|
||||||
|
limit = 100
|
||||||
|
try:
|
||||||
|
offset = int(request.GET.get("offset", "0"))
|
||||||
|
except ValueError:
|
||||||
|
offset = 0
|
||||||
|
if offset < 0:
|
||||||
|
offset = 0
|
||||||
|
results = query_set[offset:offset + limit]
|
||||||
|
if object_serializer:
|
||||||
|
count = query_set.count()
|
||||||
|
results = object_serializer(results, many=True).data
|
||||||
|
else:
|
||||||
|
count = len(query_set)
|
||||||
|
data = {"results": results,
|
||||||
|
"total": count}
|
||||||
|
return data
|
||||||
|
|
||||||
|
def dispatch(self, request, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
request.data = self._get_request_data(self.request)
|
||||||
|
except ValueError as e:
|
||||||
|
return self.error(err="invalid-request", msg=str(e))
|
||||||
|
try:
|
||||||
|
return super(APIView, self).dispatch(request, *args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(e)
|
||||||
|
return self.server_error()
|
||||||
|
|
||||||
|
|
||||||
|
class CSRFExemptAPIView(APIView):
|
||||||
|
@method_decorator(csrf_exempt)
|
||||||
|
def dispatch(self, request, *args, **kwargs):
|
||||||
|
return super(CSRFExemptAPIView, self).dispatch(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class SNServerAPIView(CSRFExemptAPIView):
|
||||||
|
def empty_response(self):
|
||||||
|
resp = HttpResponse()
|
||||||
|
resp["Content-Length"] = 0
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def response(self, data):
|
||||||
|
resp = super(SNServerAPIView, self).response(data)
|
||||||
|
resp["Content-Length"] = len(resp.content)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
def validate_serializer(serializer):
|
||||||
|
"""
|
||||||
|
@validate_serializer(TestSerializer)
|
||||||
|
def post(self, request):
|
||||||
|
return self.success(request.data)
|
||||||
|
"""
|
||||||
|
def validate(view_method):
|
||||||
|
def handle(*args, **kwargs):
|
||||||
|
self = args[0]
|
||||||
|
request = args[1]
|
||||||
|
s = serializer(data=request.data)
|
||||||
|
if s.is_valid():
|
||||||
|
request.data = s.data
|
||||||
|
request.serializer = s
|
||||||
|
return view_method(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return self.invalid_serializer(s)
|
||||||
|
|
||||||
|
return handle
|
||||||
|
|
||||||
|
return validate
|
@ -1,6 +1,6 @@
|
|||||||
# coding=utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
from django.test.testcases import TestCase
|
from django.test.testcases import TestCase
|
||||||
|
from django.core.urlresolvers import reverse
|
||||||
|
|
||||||
from rest_framework.test import APIClient
|
from rest_framework.test import APIClient
|
||||||
|
|
||||||
from account.models import User, AdminType
|
from account.models import User, AdminType
|
||||||
@ -23,6 +23,9 @@ class APITestCase(TestCase):
|
|||||||
def create_super_admin(self, username="root", password="root", login=False):
|
def create_super_admin(self, username="root", password="root", login=False):
|
||||||
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, login=login)
|
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, login=login)
|
||||||
|
|
||||||
|
def reverse(self, url_name):
|
||||||
|
return reverse(url_name)
|
||||||
|
|
||||||
def assertSuccess(self, response):
|
def assertSuccess(self, response):
|
||||||
self.assertTrue(response.data["error"] is None)
|
self.assertTrue(response.data["error"] is None)
|
||||||
|
|
@ -1,18 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
import json
|
|
||||||
|
|
||||||
from django.utils import timezone
|
|
||||||
|
|
||||||
from rest_framework import serializers
|
|
||||||
|
|
||||||
|
|
||||||
class JSONField(serializers.Field):
|
|
||||||
def to_representation(self, value):
|
|
||||||
return json.loads(value)
|
|
||||||
|
|
||||||
|
|
||||||
class DateTimeTZField(serializers.DateTimeField):
|
|
||||||
def to_representation(self, value):
|
|
||||||
self.format = "%Y-%-m-%d %-H:%-M:%-S"
|
|
||||||
value = timezone.localtime(value)
|
|
||||||
return super(DateTimeTZField, self).to_representation(value)
|
|
@ -1,103 +1,25 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from django.http import HttpResponse
|
from django.utils.crypto import get_random_string
|
||||||
from django.views.generic import View
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def JSONResponse(data, content_type="application/json"):
|
|
||||||
resp = HttpResponse(json.dumps(data, indent=4), content_type=content_type)
|
|
||||||
resp.data = data
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
class APIView(View):
|
|
||||||
def _get_request_json(self, request):
|
|
||||||
if request.method != "GET":
|
|
||||||
body = request.body
|
|
||||||
if body:
|
|
||||||
return json.loads(body.decode("utf-8"))
|
|
||||||
return {}
|
|
||||||
return request.GET
|
|
||||||
|
|
||||||
def success(self, data=None):
|
|
||||||
return JSONResponse({"error": None, "data": data})
|
|
||||||
|
|
||||||
def error(self, message, error="error"):
|
|
||||||
return JSONResponse({"error": error, "data": message})
|
|
||||||
|
|
||||||
def invalid_serializer(self, serializer):
|
|
||||||
for k, v in serializer.errors.items():
|
|
||||||
return self.error(k + ": " + v[0], error="invalid-data-format")
|
|
||||||
|
|
||||||
def server_error(self):
|
|
||||||
return self.error("Server Error")
|
|
||||||
|
|
||||||
def dispatch(self, request, *args, **kwargs):
|
|
||||||
try:
|
|
||||||
request.data = self._get_request_json(self.request)
|
|
||||||
except ValueError:
|
|
||||||
return self.error("Invalid JSON")
|
|
||||||
try:
|
|
||||||
return super(APIView, self).dispatch(request, *args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception(e)
|
|
||||||
return self.server_error()
|
|
||||||
|
|
||||||
|
|
||||||
def paginate_data(request, query_set, object_serializer):
|
|
||||||
"""
|
|
||||||
function used to paginate data
|
|
||||||
"""
|
|
||||||
need_paginate = request.GET.get("paging", None)
|
|
||||||
# if paging=true not in request.GET, then we return all data
|
|
||||||
if need_paginate != "true":
|
|
||||||
if object_serializer:
|
|
||||||
return object_serializer(query_set, many=True).data
|
|
||||||
else:
|
|
||||||
return query_set
|
|
||||||
|
|
||||||
try:
|
|
||||||
limit = int(request.GET.get("limit", "100"))
|
|
||||||
except ValueError:
|
|
||||||
limit = 100
|
|
||||||
if limit < 0:
|
|
||||||
limit = 100
|
|
||||||
|
|
||||||
try:
|
|
||||||
offset = int(request.GET.get("offset", "0"))
|
|
||||||
except ValueError:
|
|
||||||
offset = 0
|
|
||||||
if offset < 0:
|
|
||||||
offset = 0
|
|
||||||
|
|
||||||
results = query_set[offset:offset + limit]
|
|
||||||
if object_serializer:
|
|
||||||
count = query_set.count()
|
|
||||||
results = object_serializer(results, many=True).data
|
|
||||||
else:
|
|
||||||
count = len(query_set)
|
|
||||||
|
|
||||||
data = {"results": results,
|
|
||||||
"count": count}
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def rand_str(length=32, type="lower_hex"):
|
def rand_str(length=32, type="lower_hex"):
|
||||||
"""
|
"""
|
||||||
generate types of random string or number with specific length
|
生成指定长度的随机字符串或者数字, 可以用于密钥等安全场景
|
||||||
DO NOT USE TO GENERATE SECRET KEY!
|
:param length: 字符串或者数字的长度
|
||||||
|
:param type: str 代表随机字符串,num 代表随机数字
|
||||||
|
:return: 字符串
|
||||||
"""
|
"""
|
||||||
if type == "str":
|
if type == "str":
|
||||||
return ''.join(random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") for i in range(length))
|
return get_random_string(length, allowed_chars="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")
|
||||||
elif type == "lower_str":
|
elif type == "lower_str":
|
||||||
return ''.join(random.choice("abcdefghijklmnopqrstuvwxyz0123456789") for i in range(length))
|
return get_random_string(length, allowed_chars="abcdefghijklmnopqrstuvwxyz0123456789")
|
||||||
elif type == "lower_hex":
|
elif type == "lower_hex":
|
||||||
return ''.join(random.choice("0123456789abcdef") for i in range(length))
|
return random.choice("123456789abcdef") + get_random_string(length - 1, allowed_chars="0123456789abcdef")
|
||||||
else:
|
else:
|
||||||
return random.choice("123456789") + ''.join(random.choice("0123456789") for i in range(length - 1))
|
return random.choice("123456789") + get_random_string(length - 1, allowed_chars="0123456789")
|
Loading…
Reference in New Issue
Block a user