diff --git a/account/migrations/0016_auto_20151211_2230.py b/account/migrations/0016_auto_20151211_2230.py new file mode 100644 index 00000000..9c2f8798 --- /dev/null +++ b/account/migrations/0016_auto_20151211_2230.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.9 on 2015-12-11 14:30 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('account', '0015_userprofile_student_id'), + ] + + operations = [ + migrations.AddField( + model_name='user', + name='tfa_token', + field=models.CharField(blank=True, max_length=10, null=True), + ), + migrations.AddField( + model_name='user', + name='two_factor_auth', + field=models.BooleanField(default=False), + ), + ] diff --git a/account/models.py b/account/models.py index b99c6917..98897a25 100644 --- a/account/models.py +++ b/account/models.py @@ -40,6 +40,9 @@ class User(AbstractBaseUser): reset_password_token_create_time = models.DateTimeField(blank=True, null=True) # 论坛授权token auth_token = models.CharField(max_length=40, blank=True, null=True) + # 是否开启两步验证 + two_factor_auth = models.BooleanField(default=False) + tfa_token = models.CharField(max_length=10, blank=True, null=True) USERNAME_FIELD = 'username' REQUIRED_FIELDS = [] diff --git a/account/serializers.py b/account/serializers.py index 0d92ae7e..ebae3558 100644 --- a/account/serializers.py +++ b/account/serializers.py @@ -84,3 +84,7 @@ class UserProfileSerializer(serializers.ModelSerializer): model = UserProfile fields = ["avatar", "blog", "mood", "hduoj_username", "bestcoder_username", "codeforces_username", "rank", "accepted_number", "submissions_number", "problems_status", "phone_number", "school", "student_id"] + + +class ApplyTwoFactorAuthSerializer(serializers.Serializer): + code = serializers.IntegerField() diff --git a/account/views.py b/account/views.py index 5ca7612a..88919c75 100644 --- a/account/views.py +++ b/account/views.py @@ -1,11 +1,13 @@ # coding=utf-8 import codecs +import qrcode +import StringIO from django import http from django.contrib import auth from django.shortcuts import render from django.db.models import Q from django.conf import settings -from django.http import HttpResponseRedirect +from django.http import HttpResponse from django.core.exceptions import MultipleObjectsReturned from django.utils.timezone import now @@ -15,6 +17,7 @@ from utils.shortcuts import (serializer_invalid_response, error_response, success_response, error_page, paginate, rand_str) from utils.captcha import Captcha from utils.mail import send_email +from utils.otp_auth import OtpAuth from .decorators import login_required from .models import User, UserProfile @@ -23,7 +26,8 @@ from .serializers import (UserLoginSerializer, UserRegisterSerializer, UserChangePasswordSerializer, UserSerializer, EditUserSerializer, ApplyResetPasswordSerializer, ResetPasswordSerializer, - SSOSerializer, EditUserProfileSerializer, UserProfileSerializer) + SSOSerializer, EditUserProfileSerializer, + UserProfileSerializer, ApplyTwoFactorAuthSerializer) from .decorators import super_admin_required @@ -151,9 +155,9 @@ class EmailCheckAPIView(APIView): 检测邮箱是否存在,用状态码标识结果 --- """ - #这里是为了适应前端表单验证空间的要求 + # 这里是为了适应前端表单验证空间的要求 reset = request.GET.get("reset", None) - #如果reset为true说明该请求是重置密码页面发出的,要返回的状态码应正好相反 + # 如果reset为true说明该请求是重置密码页面发出的,要返回的状态码应正好相反 if reset: existed = 200 does_not_existed = 400 @@ -375,3 +379,41 @@ def reset_password_page(request, token): if (now() - user.reset_password_token_create_time).total_seconds() > 30 * 60: return error_page(request, u"链接已过期") return render(request, "oj/account/reset_password.html", {"user": user}) + + +class TwoFactorAuthAPIView(APIView): + @login_required + def get(self, request): + """ + 获取绑定二维码 + """ + user = request.user + if user.two_factor_auth: + return error_response(u"已经开启两步验证了") + token = rand_str() + user.tfa_token = token + user.save() + + image = qrcode.make(OtpAuth(token).to_uri("totp", "OnlineJudge", "OnlineJudge")) + buf = StringIO.StringIO() + image.save(buf, 'gif') + + return HttpResponse(buf.getvalue(), 'image/gif') + + @login_required + def post(self, request): + """ + 开启两步验证 + """ + serializer = ApplyTwoFactorAuthSerializer(data=request.data) + if serializer.is_valid(): + code = serializer.data["code"] + user = request.user + if OtpAuth(user.tfa_token).valid_totp(code): + user.two_factor_auth = True + user.save() + return success_response(u"开启两步验证成功") + else: + return error_response(u"验证码错误") + else: + return serializer_invalid_response(serializer) \ No newline at end of file diff --git a/oj/urls.py b/oj/urls.py index 77a92b66..94e9d00f 100644 --- a/oj/urls.py +++ b/oj/urls.py @@ -6,7 +6,8 @@ from django.views.generic import TemplateView from account.views import (UserLoginAPIView, UsernameCheckAPIView, UserRegisterAPIView, UserChangePasswordAPIView, EmailCheckAPIView, UserAdminAPIView, UserInfoAPIView, ResetPasswordAPIView, - ApplyResetPasswordAPIView, SSOAPIView, UserProfileAPIView) + ApplyResetPasswordAPIView, SSOAPIView, UserProfileAPIView, + TwoFactorAuthAPIView) from announcement.views import AnnouncementAdminAPIView @@ -132,7 +133,8 @@ urlpatterns = [ url(r'^account/sso/$', SSOAPIView.as_view(), name="sso_api"), url(r'^api/account/userprofile/$', UserProfileAPIView.as_view(), name="userprofile_api"), url(r'^reset_password/$', TemplateView.as_view(template_name="oj/account/apply_reset_password.html"), name="apply_reset_password_page"), - url(r'^reset_password/t/(?P\w+)/$', "account.views.reset_password_page", name="reset_password_page") + url(r'^reset_password/t/(?P\w+)/$', "account.views.reset_password_page", name="reset_password_page"), + url(r'^api/two_factor_auth/$', TwoFactorAuthAPIView.as_view(), name="two_factor_auth_api"), ] diff --git a/utils/otp_auth.py b/utils/otp_auth.py new file mode 100644 index 00000000..12773c4e --- /dev/null +++ b/utils/otp_auth.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- +""" + otpauth + ~~~~~~~ + + Implements two-step verification of HOTP/TOTP. + + :copyright: (c) 2013 - 2015 by Hsiaoming Yang. + :license: BSD, see LICENSE for more details. +""" +import sys +import time +import hmac +import base64 +import struct +import hashlib +import warnings + + +if sys.version_info[0] == 3: + PY2 = False + string_type = str +else: + PY2 = True + string_type = unicode + range = xrange + + +__author__ = 'Hsiaoming Yang ' +__homepage__ = 'https://github.com/lepture/otpauth' +__version__ = '1.0.1' + + +__all__ = ['OtpAuth', 'HOTP', 'TOTP', 'generate_hotp', 'generate_totp'] + + +HOTP = 'hotp' +TOTP = 'totp' + + +class OtpAuth(object): + """One Time Password Authentication. + + :param secret: A secret token for the authentication. + """ + + def __init__(self, secret): + self.secret = secret + + def hotp(self, counter=4): + """Generate a HOTP code. + + :param counter: HOTP is a counter based algorithm. + """ + return generate_hotp(self.secret, counter) + + def totp(self, period=30, timestamp=None): + """Generate a TOTP code. + + A TOTP code is an extension of HOTP algorithm. + + :param period: A period that a TOTP code is valid in seconds + :param timestamp: Create TOTP at this given timestamp + """ + return generate_totp(self.secret, period, timestamp) + + def valid_hotp(self, code, last=0, trials=100): + """Valid a HOTP code. + + :param code: A number that is less than 6 characters. + :param last: Guess HOTP code from last + 1 range. + :param trials: Guest HOTP code end at last + trials + 1. + """ + if not valid_code(code): + return False + + code = bytes(int(code)) + for i in range(last + 1, last + trials + 1): + if compare_digest(bytes(self.hotp(counter=i)), code): + return i + return False + + def valid_totp(self, code, period=30, timestamp=None): + """Valid a TOTP code. + + :param code: A number that is less than 6 characters. + :param period: A period that a TOTP code is valid in seconds + :param timestamp: Validate TOTP at this given timestamp + """ + if not valid_code(code): + return False + return compare_digest( + bytes(self.totp(period, timestamp)), + bytes(int(code)) + ) + + @property + def encoded_secret(self): + secret = base64.b32encode(to_bytes(self.secret)) + # bytes to string + secret = secret.decode('utf-8') + # remove pad string + return secret.strip('=') + + def to_uri(self, type, label, issuer, counter=None): + """Generate the otpauth protocal string. + + :param type: Algorithm type, hotp or totp. + :param label: Label of the identifier. + :param issuer: The company, the organization or something else. + :param counter: Counter of the HOTP algorithm. + """ + type = type.lower() + + if type not in ('hotp', 'totp'): + raise ValueError('type must be hotp or totp') + + if type == 'hotp' and not counter: + raise ValueError('HOTP type authentication need counter') + + # https://code.google.com/p/google-authenticator/wiki/KeyUriFormat + url = ('otpauth://%(type)s/%(label)s?secret=%(secret)s' + '&issuer=%(issuer)s') + dct = dict( + type=type, label=label, issuer=issuer, + secret=self.encoded_secret, counter=counter + ) + ret = url % dct + if type == 'hotp': + ret = '%s&counter=%s' % (ret, counter) + return ret + + def to_google(self, type, label, issuer, counter=None): + """Generate the otpauth protocal string for Google Authenticator. + + .. deprecated:: 0.2.0 + Use :func:`to_uri` instead. + """ + warnings.warn('deprecated, use to_uri instead', DeprecationWarning) + return self.to_uri(type, label, issuer, counter) + + +def generate_hotp(secret, counter=4): + """Generate a HOTP code. + + :param secret: A secret token for the authentication. + :param counter: HOTP is a counter based algorithm. + """ + # https://tools.ietf.org/html/rfc4226 + msg = struct.pack('>Q', counter) + digest = hmac.new(to_bytes(secret), msg, hashlib.sha1).digest() + + ob = digest[19] + if PY2: + ob = ord(ob) + + pos = ob & 15 + base = struct.unpack('>I', digest[pos:pos + 4])[0] & 0x7fffffff + token = base % 1000000 + return token + + +def generate_totp(secret, period=30, timestamp=None): + """Generate a TOTP code. + + A TOTP code is an extension of HOTP algorithm. + + :param secret: A secret token for the authentication. + :param period: A period that a TOTP code is valid in seconds + :param timestamp: Current time stamp. + """ + if timestamp is None: + timestamp = time.time() + counter = int(timestamp) // period + return generate_hotp(secret, counter) + + +def to_bytes(text): + if isinstance(text, string_type): + # Python3 str -> bytes + # Python2 unicode -> str + text = text.encode('utf-8') + return text + + +def valid_code(code): + code = string_type(code) + return code.isdigit() and len(code) <= 6 + + +def compare_digest(a, b): + func = getattr(hmac, 'compare_digest', None) + if func: + return func(a, b) + + # fallback + if len(a) != len(b): + return False + + rv = 0 + if PY2: + from itertools import izip + for x, y in izip(a, b): + rv |= ord(x) ^ ord(y) + else: + for x, y in zip(a, b): + rv |= x ^ y + return rv == 0 \ No newline at end of file