diff --git a/conf/tests.py b/conf/tests.py
index f6debd6..55f4853 100644
--- a/conf/tests.py
+++ b/conf/tests.py
@@ -66,10 +66,11 @@ class WebsiteConfigAPITest(APITestCase):
self.create_super_admin()
url = self.reverse("website_config_api")
data = {"website_base_url": "http://test.com", "website_name": "test name",
- "website_name_shortcut": "test oj", "website_footer": "test",
+ "website_name_shortcut": "test oj", "website_footer": "",
"allow_register": True, "submission_list_show_all": False}
resp = self.client.post(url, data=data)
self.assertSuccess(resp)
+ self.assertEqual(SysOptions.website_footer, "
")
def test_get_website_config(self):
# do not need to login
diff --git a/conf/views.py b/conf/views.py
index a4f070c..b4e47bb 100644
--- a/conf/views.py
+++ b/conf/views.py
@@ -13,6 +13,7 @@ from judge.languages import languages, spj_languages
from options.options import SysOptions
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from utils.shortcuts import send_email
+from utils.xss_filter import XSSHtml
from .models import JudgeServer
from .serializers import (CreateEditWebsiteConfigSerializer,
CreateSMTPConfigSerializer, EditSMTPConfigSerializer,
@@ -84,6 +85,9 @@ class WebsiteConfigAPI(APIView):
@super_admin_required
def post(self, request):
for k, v in request.data.items():
+ if k == "website_footer":
+ with XSSHtml() as parser:
+ v = parser.clean(v)
setattr(SysOptions, k, v)
return self.success()
diff --git a/utils/models.py b/utils/models.py
index 3c11452..9a1cd0b 100644
--- a/utils/models.py
+++ b/utils/models.py
@@ -1,14 +1,10 @@
from django.contrib.postgres.fields import JSONField # NOQA
from django.db import models
-from utils.xss_filter import XssHtml
+from utils.xss_filter import XSSHtml
class RichTextField(models.TextField):
def get_prep_value(self, value):
- if not value:
- value = ""
- parser = XssHtml()
- parser.feed(value)
- parser.close()
- return parser.getHtml()
+ with XSSHtml() as parser:
+ return parser.clean(value or "")
diff --git a/utils/xss_filter.py b/utils/xss_filter.py
index 34d65a8..1b45d89 100644
--- a/utils/xss_filter.py
+++ b/utils/xss_filter.py
@@ -30,7 +30,7 @@ import copy
from html.parser import HTMLParser
-class XssHtml(HTMLParser):
+class XSSHtml(HTMLParser):
allow_tags = ['a', 'img', 'br', 'strong', 'b', 'code', 'pre',
'p', 'div', 'em', 'span', 'h1', 'h2', 'h3', 'h4',
'h5', 'h6', 'blockquote', 'ul', 'ol', 'tr', 'th', 'td',
@@ -53,7 +53,17 @@ class XssHtml(HTMLParser):
self.start = []
self.data = []
- def getHtml(self):
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ super().close()
+
+ def clean(self, content):
+ self.feed(content)
+ return self.get_html()
+
+ def get_html(self):
"""
Get the safe html code
"""
@@ -188,11 +198,11 @@ class XssHtml(HTMLParser):
if "__main__" == __name__:
- parser = XssHtml()
- parser.feed("""
>M
- """) - parser.close() - print(parser.getHtml()) + with XSSHtml() as parser: + ret = parser.clean(""">M
+ +