From 1c08e8023653bb9129fb93f5578f71d841b92b47 Mon Sep 17 00:00:00 2001
From: Fuxinyan <2386457645@qq.com>
Date: Tue, 4 Jun 2024 13:58:47 +0800
Subject: [PATCH] Fix Bug 536
---
app/Login.py | 62 ++++++++++++++++----
app/admin_service.py | 46 ++++++++-------
app/test/test_bug536_jiangwangzhe.py | 88 ++++++++++++++++++++++++++++
3 files changed, 165 insertions(+), 31 deletions(-)
create mode 100644 app/test/test_bug536_jiangwangzhe.py
diff --git a/app/Login.py b/app/Login.py
index cd750d1..9445748 100644
--- a/app/Login.py
+++ b/app/Login.py
@@ -1,8 +1,12 @@
import hashlib
import string
from datetime import datetime, timedelta
+
+import unicodedata
+
from UseSqlite import InsertQuery, RecordQuery
+
def md5(s):
'''
MD5摘要
@@ -12,14 +16,16 @@ def md5(s):
h = hashlib.md5(s.encode(encoding='utf-8'))
return h.hexdigest()
+
# import model.user after the defination of md5(s) to avoid circular import
from model.user import get_user_by_username, insert_user, update_password_by_username
path_prefix = '/var/www/wordfreq/wordfreq/'
path_prefix = './' # comment this line in deployment
-def verify_pass(newpass,oldpass):
- if(newpass==oldpass):
+
+def verify_pass(newpass, oldpass):
+ if (newpass == oldpass):
return True
@@ -31,7 +37,7 @@ def verify_user(username, password):
def add_user(username, password):
start_date = datetime.now().strftime('%Y%m%d')
- expiry_date = (datetime.now() + timedelta(days=30)).strftime('%Y%m%d') # will expire after 30 days
+ expiry_date = (datetime.now() + timedelta(days=30)).strftime('%Y%m%d') # will expire after 30 days
# 将用户名和密码一起加密,以免暴露不同用户的相同密码
password = md5(username + password)
insert_user(username=username, password=password, start_date=start_date, expiry_date=expiry_date)
@@ -53,7 +59,7 @@ def change_password(username, old_password, new_password):
if not verify_user(username, old_password): # 旧密码错误
return False
# 将用户名和密码一起加密,以免暴露不同用户的相同密码
- if verify_pass(new_password,old_password): #新旧密码一致
+ if verify_pass(new_password, old_password): #新旧密码一致
return False
update_password_by_username(username, new_password)
return True
@@ -66,30 +72,64 @@ def get_expiry_date(username):
else:
return user.expiry_date
+
class UserName:
def __init__(self, username):
self.username = username
+ def contains_chinese(self):
+ for char in self.username:
+ # Check if the character is in the CJK (Chinese, Japanese, Korean) Unicode block
+ if unicodedata.name(char).startswith('CJK UNIFIED IDEOGRAPH'):
+ return True
+ return False
+
def validate(self):
if len(self.username) > 20:
return f'{self.username} is too long. The user name cannot exceed 20 characters.'
- if self.username.startswith('.'): # a user name must not start with a dot
+ if self.username.startswith('.'): # a user name must not start with a dot
return 'Period (.) is not allowed as the first letter in the user name.'
- if ' ' in self.username: # a user name must not include a whitespace
+ if ' ' in self.username: # a user name must not include a whitespace
return 'Whitespace is not allowed in the user name.'
- for c in self.username: # a user name must not include special characters, except non-leading periods or underscores
+ for c in self.username: # a user name must not include special characters, except non-leading periods or underscores
if c in string.punctuation and c != '.' and c != '_':
return f'{c} is not allowed in the user name.'
- if self.username in ['signup', 'login', 'logout', 'reset', 'mark', 'back', 'unfamiliar', 'familiar', 'del', 'admin']:
+ if self.username in ['signup', 'login', 'logout', 'reset', 'mark', 'back', 'unfamiliar', 'familiar', 'del',
+ 'admin']:
return 'You used a restricted word as your user name. Please come up with a better one.'
+ if self.contains_chinese():
+ return 'Chinese characters are not allowed in the user name.'
+ return 'OK'
+
+class Password:
+ def __init__(self, password):
+ self.password = password
+
+ def contains_chinese(self):
+ for char in self.password:
+ # Check if the character is in the CJK (Chinese, Japanese, Korean) Unicode block
+ if unicodedata.name(char).startswith('CJK UNIFIED IDEOGRAPH'):
+ return True
+ return False
+
+ def validate(self):
+ if len(self.password) < 4:
+ return 'Password must be at least 4 characters long.'
+ if ' ' in self.password:
+ return 'Password cannot contain spaces.'
+ if self.contains_chinese():
+ return 'Chinese characters are not allowed in the password.'
return 'OK'
class WarningMessage:
- def __init__(self, s):
+ def __init__(self, s, type='username'):
self.s = s
+ self.type = type
def __str__(self):
- return UserName(self.s).validate()
-
+ if self.type == 'username':
+ return UserName(self.s).validate()
+ if self.type == 'password':
+ return Password(self.s).validate()
diff --git a/app/admin_service.py b/app/admin_service.py
index a604b5e..c461af9 100644
--- a/app/admin_service.py
+++ b/app/admin_service.py
@@ -1,5 +1,6 @@
# System Library
from flask import *
+from markupsafe import escape
# Personal library
from Yaml import yml
@@ -37,6 +38,22 @@ def admin():
@adminService.route("/admin/article", methods=["GET", "POST"])
def article():
+
+ def _make_title_and_content(article_lst):
+ for article in article_lst:
+ text = escape(article.text) # Fix XSS vulnerability, contributed by Xu Xuan
+ article.title = text.split("\n")[0]
+ article.content = '
'.join(text.split("\n")[1:])
+
+
+ def _update_context():
+ article_len = get_number_of_articles()
+ context["article_number"] = article_len
+ context["text_list"] = get_page_articles(_cur_page, _page_size)
+ _articles = get_page_articles(_cur_page, _page_size)
+ _make_title_and_content(_articles)
+ context["text_list"] = _articles
+
global _cur_page, _page_size
is_admin = check_is_admin()
@@ -44,20 +61,15 @@ def article():
return is_admin
_article_number = get_number_of_articles()
+
try:
- _page_size = min(
- max(1, int(request.args.get("size", 5))), _article_number
- ) # 最小的size是1
- _cur_page = min(
- max(1, int(request.args.get("page", 1))), _article_number // _page_size + (_article_number % _page_size > 0)
- ) # 最小的page是1
+ _page_size = min(max(1, int(request.args.get("size", 5))), _article_number) # 最小的size是1
+ _cur_page = min(max(1, int(request.args.get("page", 1))), _article_number // _page_size + (_article_number % _page_size > 0)) # 最小的page是1
except ValueError:
- return "page parmas must be int!"
-
+ return "page parameters must be integer!"
+
_articles = get_page_articles(_cur_page, _page_size)
- for article in _articles: # 获取每篇文章的title
- article.title = article.text.split("\n")[0]
- article.content = '
'.join(article.text.split("\n")[1:])
+ _make_title_and_content(_articles)
context = {
"article_number": _article_number,
@@ -67,23 +79,16 @@ def article():
"username": session.get("username"),
}
- def _update_context():
- article_len = get_number_of_articles()
- context["article_number"] = article_len
- context["text_list"] = get_page_articles(_cur_page, _page_size)
- _articles = get_page_articles(_cur_page, _page_size)
- for article in _articles: # 获取每篇文章的title
- article.title = article.text.split("\n")[0]
- context["text_list"] = _articles
if request.method == "GET":
try:
delete_id = int(request.args.get("delete_id", 0))
except:
- return "Delete article ID must be int!"
+ return "Delete article ID must be integer!"
if delete_id: # delete article
delete_article_by_id(delete_id)
_update_context()
+
elif request.method == "POST":
data = request.form
content = data.get("content", "")
@@ -97,6 +102,7 @@ def article():
_update_context()
title = content.split('\n')[0]
flash(f'Article added. Title: {title}')
+
return render_template("admin_manage_article.html", **context)
diff --git a/app/test/test_bug536_jiangwangzhe.py b/app/test/test_bug536_jiangwangzhe.py
new file mode 100644
index 0000000..4862486
--- /dev/null
+++ b/app/test/test_bug536_jiangwangzhe.py
@@ -0,0 +1,88 @@
+from selenium.webdriver.common.alert import Alert
+from selenium.webdriver.common.by import By
+from selenium.webdriver.support.ui import WebDriverWait
+from selenium.webdriver.support import expected_conditions as EC
+
+
+# 对用户名不能为中文进行测试
+def test_register_username_with_chinese(driver, URL):
+ try:
+ driver.get(URL + "/signup")
+
+ # 等待用户名输入框出现
+ username_elem = WebDriverWait(driver, 10).until(
+ EC.presence_of_element_located((By.ID, 'username'))
+ )
+ username_elem.send_keys("测试用户") # 输入中文用户名
+
+ # 等待密码输入框出现
+ password_elem = WebDriverWait(driver, 10).until(
+ EC.presence_of_element_located((By.ID, 'password'))
+ )
+ password_elem.send_keys("validPassword123") # 输入有效密码
+
+ # 等待确认密码输入框出现
+ password2_elem = WebDriverWait(driver, 10).until(
+ EC.presence_of_element_located((By.ID, 'password2'))
+ )
+ password2_elem.send_keys("validPassword123") # 输入有效确认密码
+
+ # 等待注册按钮出现并点击
+ signup_button = WebDriverWait(driver, 10).until(
+ EC.element_to_be_clickable((By.XPATH, '//button[@onclick="signup()"]'))
+ )
+ signup_button.click()
+
+ # 等待警告框出现并接受
+ WebDriverWait(driver, 10).until(EC.alert_is_present())
+ alert = driver.switch_to.alert
+ alert_text = alert.text
+ print(f"警告文本: {alert_text}")
+ assert alert_text == "Chinese characters are not allowed in the user name." # 根据实际的警告文本进行断言
+ alert.accept()
+
+ except Exception as e:
+ print(f"发生错误: {e}")
+ raise
+
+
+# 对注册时密码不能是中文进行测试
+def test_register_password_with_chinese(driver, URL):
+ try:
+ driver.get(URL + "/signup")
+
+ # 等待用户名输入框出现
+ username_elem = WebDriverWait(driver, 10).until(
+ EC.presence_of_element_located((By.ID, 'username'))
+ )
+ username_elem.send_keys("validUsername123") # 输入有效用户名
+
+ # 等待密码输入框出现
+ password_elem = WebDriverWait(driver, 10).until(
+ EC.presence_of_element_located((By.ID, 'password'))
+ )
+ password_elem.send_keys("测试密码") # 输入中文密码
+
+ # 等待确认密码输入框出现
+ password2_elem = WebDriverWait(driver, 10).until(
+ EC.presence_of_element_located((By.ID, 'password2'))
+ )
+ password2_elem.send_keys("测试密码") # 输入中文确认密码
+
+ # 等待注册按钮出现并点击
+ signup_button = WebDriverWait(driver, 10).until(
+ EC.element_to_be_clickable((By.XPATH, '//button[@onclick="signup()"]'))
+ )
+ signup_button.click()
+
+ # 等待警告框出现并接受
+ WebDriverWait(driver, 10).until(EC.alert_is_present())
+ alert = driver.switch_to.alert
+ alert_text = alert.text
+ print(f"警告文本: {alert_text}")
+ assert alert_text == "Chinese characters are not allowed in the password." # 根据实际的警告文本进行断言
+ alert.accept()
+
+ except Exception as e:
+ print(f"发生错误: {e}")
+ raise