1
0
Fork 0

Compare commits

..

1 Commits

Author SHA1 Message Date
Fuxinyan 1c08e80236 Fix Bug 536 2024-06-04 13:58:47 +08:00
3 changed files with 165 additions and 31 deletions

View File

@ -1,8 +1,12 @@
import hashlib import hashlib
import string import string
from datetime import datetime, timedelta from datetime import datetime, timedelta
import unicodedata
from UseSqlite import InsertQuery, RecordQuery from UseSqlite import InsertQuery, RecordQuery
def md5(s): def md5(s):
''' '''
MD5摘要 MD5摘要
@ -12,14 +16,16 @@ def md5(s):
h = hashlib.md5(s.encode(encoding='utf-8')) h = hashlib.md5(s.encode(encoding='utf-8'))
return h.hexdigest() return h.hexdigest()
# import model.user after the defination of md5(s) to avoid circular import # import model.user after the defination of md5(s) to avoid circular import
from model.user import get_user_by_username, insert_user, update_password_by_username from model.user import get_user_by_username, insert_user, update_password_by_username
path_prefix = '/var/www/wordfreq/wordfreq/' path_prefix = '/var/www/wordfreq/wordfreq/'
path_prefix = './' # comment this line in deployment path_prefix = './' # comment this line in deployment
def verify_pass(newpass,oldpass):
if(newpass==oldpass): def verify_pass(newpass, oldpass):
if (newpass == oldpass):
return True return True
@ -31,7 +37,7 @@ def verify_user(username, password):
def add_user(username, password): def add_user(username, password):
start_date = datetime.now().strftime('%Y%m%d') start_date = datetime.now().strftime('%Y%m%d')
expiry_date = (datetime.now() + timedelta(days=30)).strftime('%Y%m%d') # will expire after 30 days expiry_date = (datetime.now() + timedelta(days=30)).strftime('%Y%m%d') # will expire after 30 days
# 将用户名和密码一起加密,以免暴露不同用户的相同密码 # 将用户名和密码一起加密,以免暴露不同用户的相同密码
password = md5(username + password) password = md5(username + password)
insert_user(username=username, password=password, start_date=start_date, expiry_date=expiry_date) insert_user(username=username, password=password, start_date=start_date, expiry_date=expiry_date)
@ -53,7 +59,7 @@ def change_password(username, old_password, new_password):
if not verify_user(username, old_password): # 旧密码错误 if not verify_user(username, old_password): # 旧密码错误
return False return False
# 将用户名和密码一起加密,以免暴露不同用户的相同密码 # 将用户名和密码一起加密,以免暴露不同用户的相同密码
if verify_pass(new_password,old_password): #新旧密码一致 if verify_pass(new_password, old_password): #新旧密码一致
return False return False
update_password_by_username(username, new_password) update_password_by_username(username, new_password)
return True return True
@ -66,30 +72,64 @@ def get_expiry_date(username):
else: else:
return user.expiry_date return user.expiry_date
class UserName: class UserName:
def __init__(self, username): def __init__(self, username):
self.username = username self.username = username
def contains_chinese(self):
for char in self.username:
# Check if the character is in the CJK (Chinese, Japanese, Korean) Unicode block
if unicodedata.name(char).startswith('CJK UNIFIED IDEOGRAPH'):
return True
return False
def validate(self): def validate(self):
if len(self.username) > 20: if len(self.username) > 20:
return f'{self.username} is too long. The user name cannot exceed 20 characters.' return f'{self.username} is too long. The user name cannot exceed 20 characters.'
if self.username.startswith('.'): # a user name must not start with a dot if self.username.startswith('.'): # a user name must not start with a dot
return 'Period (.) is not allowed as the first letter in the user name.' return 'Period (.) is not allowed as the first letter in the user name.'
if ' ' in self.username: # a user name must not include a whitespace if ' ' in self.username: # a user name must not include a whitespace
return 'Whitespace is not allowed in the user name.' return 'Whitespace is not allowed in the user name.'
for c in self.username: # a user name must not include special characters, except non-leading periods or underscores for c in self.username: # a user name must not include special characters, except non-leading periods or underscores
if c in string.punctuation and c != '.' and c != '_': if c in string.punctuation and c != '.' and c != '_':
return f'{c} is not allowed in the user name.' return f'{c} is not allowed in the user name.'
if self.username in ['signup', 'login', 'logout', 'reset', 'mark', 'back', 'unfamiliar', 'familiar', 'del', 'admin']: if self.username in ['signup', 'login', 'logout', 'reset', 'mark', 'back', 'unfamiliar', 'familiar', 'del',
'admin']:
return 'You used a restricted word as your user name. Please come up with a better one.' return 'You used a restricted word as your user name. Please come up with a better one.'
if self.contains_chinese():
return 'Chinese characters are not allowed in the user name.'
return 'OK'
class Password:
def __init__(self, password):
self.password = password
def contains_chinese(self):
for char in self.password:
# Check if the character is in the CJK (Chinese, Japanese, Korean) Unicode block
if unicodedata.name(char).startswith('CJK UNIFIED IDEOGRAPH'):
return True
return False
def validate(self):
if len(self.password) < 4:
return 'Password must be at least 4 characters long.'
if ' ' in self.password:
return 'Password cannot contain spaces.'
if self.contains_chinese():
return 'Chinese characters are not allowed in the password.'
return 'OK' return 'OK'
class WarningMessage: class WarningMessage:
def __init__(self, s): def __init__(self, s, type='username'):
self.s = s self.s = s
self.type = type
def __str__(self): def __str__(self):
return UserName(self.s).validate() if self.type == 'username':
return UserName(self.s).validate()
if self.type == 'password':
return Password(self.s).validate()

View File

@ -1,5 +1,6 @@
# System Library # System Library
from flask import * from flask import *
from markupsafe import escape
# Personal library # Personal library
from Yaml import yml from Yaml import yml
@ -37,6 +38,22 @@ def admin():
@adminService.route("/admin/article", methods=["GET", "POST"]) @adminService.route("/admin/article", methods=["GET", "POST"])
def article(): def article():
def _make_title_and_content(article_lst):
for article in article_lst:
text = escape(article.text) # Fix XSS vulnerability, contributed by Xu Xuan
article.title = text.split("\n")[0]
article.content = '<br/>'.join(text.split("\n")[1:])
def _update_context():
article_len = get_number_of_articles()
context["article_number"] = article_len
context["text_list"] = get_page_articles(_cur_page, _page_size)
_articles = get_page_articles(_cur_page, _page_size)
_make_title_and_content(_articles)
context["text_list"] = _articles
global _cur_page, _page_size global _cur_page, _page_size
is_admin = check_is_admin() is_admin = check_is_admin()
@ -44,20 +61,15 @@ def article():
return is_admin return is_admin
_article_number = get_number_of_articles() _article_number = get_number_of_articles()
try: try:
_page_size = min( _page_size = min(max(1, int(request.args.get("size", 5))), _article_number) # 最小的size是1
max(1, int(request.args.get("size", 5))), _article_number _cur_page = min(max(1, int(request.args.get("page", 1))), _article_number // _page_size + (_article_number % _page_size > 0)) # 最小的page是1
) # 最小的size是1
_cur_page = min(
max(1, int(request.args.get("page", 1))), _article_number // _page_size + (_article_number % _page_size > 0)
) # 最小的page是1
except ValueError: except ValueError:
return "page parmas must be int!" return "page parameters must be integer!"
_articles = get_page_articles(_cur_page, _page_size) _articles = get_page_articles(_cur_page, _page_size)
for article in _articles: # 获取每篇文章的title _make_title_and_content(_articles)
article.title = article.text.split("\n")[0]
article.content = '<br/>'.join(article.text.split("\n")[1:])
context = { context = {
"article_number": _article_number, "article_number": _article_number,
@ -67,23 +79,16 @@ def article():
"username": session.get("username"), "username": session.get("username"),
} }
def _update_context():
article_len = get_number_of_articles()
context["article_number"] = article_len
context["text_list"] = get_page_articles(_cur_page, _page_size)
_articles = get_page_articles(_cur_page, _page_size)
for article in _articles: # 获取每篇文章的title
article.title = article.text.split("\n")[0]
context["text_list"] = _articles
if request.method == "GET": if request.method == "GET":
try: try:
delete_id = int(request.args.get("delete_id", 0)) delete_id = int(request.args.get("delete_id", 0))
except: except:
return "Delete article ID must be int!" return "Delete article ID must be integer!"
if delete_id: # delete article if delete_id: # delete article
delete_article_by_id(delete_id) delete_article_by_id(delete_id)
_update_context() _update_context()
elif request.method == "POST": elif request.method == "POST":
data = request.form data = request.form
content = data.get("content", "") content = data.get("content", "")
@ -97,6 +102,7 @@ def article():
_update_context() _update_context()
title = content.split('\n')[0] title = content.split('\n')[0]
flash(f'Article added. Title: {title}') flash(f'Article added. Title: {title}')
return render_template("admin_manage_article.html", **context) return render_template("admin_manage_article.html", **context)

View File

@ -0,0 +1,88 @@
from selenium.webdriver.common.alert import Alert
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
# 对用户名不能为中文进行测试
def test_register_username_with_chinese(driver, URL):
try:
driver.get(URL + "/signup")
# 等待用户名输入框出现
username_elem = WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.ID, 'username'))
)
username_elem.send_keys("测试用户") # 输入中文用户名
# 等待密码输入框出现
password_elem = WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.ID, 'password'))
)
password_elem.send_keys("validPassword123") # 输入有效密码
# 等待确认密码输入框出现
password2_elem = WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.ID, 'password2'))
)
password2_elem.send_keys("validPassword123") # 输入有效确认密码
# 等待注册按钮出现并点击
signup_button = WebDriverWait(driver, 10).until(
EC.element_to_be_clickable((By.XPATH, '//button[@onclick="signup()"]'))
)
signup_button.click()
# 等待警告框出现并接受
WebDriverWait(driver, 10).until(EC.alert_is_present())
alert = driver.switch_to.alert
alert_text = alert.text
print(f"警告文本: {alert_text}")
assert alert_text == "Chinese characters are not allowed in the user name." # 根据实际的警告文本进行断言
alert.accept()
except Exception as e:
print(f"发生错误: {e}")
raise
# 对注册时密码不能是中文进行测试
def test_register_password_with_chinese(driver, URL):
try:
driver.get(URL + "/signup")
# 等待用户名输入框出现
username_elem = WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.ID, 'username'))
)
username_elem.send_keys("validUsername123") # 输入有效用户名
# 等待密码输入框出现
password_elem = WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.ID, 'password'))
)
password_elem.send_keys("测试密码") # 输入中文密码
# 等待确认密码输入框出现
password2_elem = WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.ID, 'password2'))
)
password2_elem.send_keys("测试密码") # 输入中文确认密码
# 等待注册按钮出现并点击
signup_button = WebDriverWait(driver, 10).until(
EC.element_to_be_clickable((By.XPATH, '//button[@onclick="signup()"]'))
)
signup_button.click()
# 等待警告框出现并接受
WebDriverWait(driver, 10).until(EC.alert_is_present())
alert = driver.switch_to.alert
alert_text = alert.text
print(f"警告文本: {alert_text}")
assert alert_text == "Chinese characters are not allowed in the password." # 根据实际的警告文本进行断言
alert.accept()
except Exception as e:
print(f"发生错误: {e}")
raise