|
| 1 | +import asyncio |
1 | 2 | import datetime |
2 | 3 | import uuid |
3 | 4 | from pathlib import Path |
|
6 | 7 | from starlette.requests import Request |
7 | 8 | from starlette.responses import HTMLResponse, FileResponse |
8 | 9 | from starlette.staticfiles import StaticFiles |
9 | | -import os, shutil |
10 | | -from core.utils import error_ip_limit, upload_ip_limit, get_code, storage |
| 10 | +import os |
| 11 | +import shutil |
| 12 | +from core.utils import error_ip_limit, upload_ip_limit, get_code, storage, delete_expire_files |
11 | 13 | from core.depends import admin_required |
12 | | -from settings import settings |
13 | 14 | from fastapi import FastAPI, Depends, UploadFile, Form, File, HTTPException, BackgroundTasks |
14 | | - |
15 | 15 | from core.database import init_models, Options, Codes, get_session |
| 16 | +from settings import settings |
16 | 17 |
|
17 | 18 | # 实例化FastAPI |
18 | | -app = FastAPI(debug=settings.DEBUG, docs_url=None, redoc_url=None) |
| 19 | +app = FastAPI(debug=settings.DEBUG, redoc_url=None, ) |
| 20 | + |
| 21 | + |
| 22 | +@app.on_event('startup') |
| 23 | +async def startup(s: AsyncSession = Depends(get_session)): |
| 24 | + # 初始化数据库 |
| 25 | + await init_models(s) |
| 26 | + # 启动后台任务,不定时删除过期文件 |
| 27 | + asyncio.create_task(delete_expire_files()) |
| 28 | + |
19 | 29 |
|
20 | 30 | # 数据存储文件夹 |
21 | 31 | DATA_ROOT = Path(settings.DATA_ROOT) |
|
28 | 38 |
|
29 | 39 | # 静态文件夹,这个固定就行了,静态资源都放在这里 |
30 | 40 | app.mount('/static', StaticFiles(directory='./static'), name="static") |
| 41 | + |
31 | 42 | # 首页页面 |
32 | | -index_html = open('templates/index.html', 'r', encoding='utf-8').read() \ |
33 | | - .replace('{{title}}', settings.TITLE) \ |
34 | | - .replace('{{description}}', settings.DESCRIPTION) \ |
35 | | - .replace('{{keywords}}', settings.KEYWORDS) \ |
36 | | - .replace("'{{fileSizeLimit}}'", str(settings.FILE_SIZE_LIMIT)) |
| 43 | +index_html = open('templates/index.html', 'r', encoding='utf-8').read() |
37 | 44 | # 管理页面 |
38 | | -admin_html = open('templates/admin.html', 'r', encoding='utf-8').read() \ |
39 | | - .replace('{{title}}', settings.TITLE) \ |
40 | | - .replace('{{description}}', settings.DESCRIPTION) \ |
41 | | - .replace('{{admin_address}}', settings.ADMIN_ADDRESS) \ |
42 | | - .replace('{{keywords}}', settings.KEYWORDS) |
43 | | - |
44 | | - |
45 | | -@app.on_event('startup') |
46 | | -async def startup(): |
47 | | - # 初始化数据库 |
48 | | - await init_models() |
| 45 | +admin_html = open('templates/admin.html', 'r', encoding='utf-8').read() |
49 | 46 |
|
50 | 47 |
|
51 | 48 | @app.get('/') |
52 | 49 | async def index(): |
53 | | - return HTMLResponse(index_html) |
| 50 | + return HTMLResponse( |
| 51 | + index_html.replace('{{title}}', settings.TITLE).replace('{{description}}', settings.DESCRIPTION).replace( |
| 52 | + '{{keywords}}', settings.KEYWORDS).replace("'{{fileSizeLimit}}'", str(settings.FILE_SIZE_LIMIT)) |
| 53 | + ) |
54 | 54 |
|
55 | 55 |
|
56 | 56 | @app.get(f'/{settings.ADMIN_ADDRESS}', description='管理页面') |
57 | 57 | async def admin(): |
58 | | - return HTMLResponse(admin_html) |
| 58 | + return HTMLResponse( |
| 59 | + admin_html.replace('{{title}}', settings.TITLE).replace('{{description}}', settings.DESCRIPTION).replace( |
| 60 | + '{{admin_address}}', settings.ADMIN_ADDRESS).replace('{{keywords}}', settings.KEYWORDS) |
| 61 | + ) |
59 | 62 |
|
60 | 63 |
|
61 | 64 | @app.get(f'/{settings.ADMIN_ADDRESS}/files', dependencies=[Depends(admin_required)]) |
@@ -150,14 +153,10 @@ async def admin_patch(request: Request, s: AsyncSession = Depends(get_session)): |
150 | 153 | await s.execute(update(Options).where(Options.key == key).values(value=value)) |
151 | 154 | await settings.update(key, value) |
152 | 155 | await s.commit() |
| 156 | + await settings.updates([[i.id, i.key, i.value] for i in (await s.execute(select(Options))).scalars().all()]) |
153 | 157 | return {'detail': '修改成功'} |
154 | 158 |
|
155 | 159 |
|
156 | | -@app.get('/') |
157 | | -async def index(): |
158 | | - return HTMLResponse(index_html) |
159 | | - |
160 | | - |
161 | 160 | @app.post('/') |
162 | 161 | async def index(code: str, ip: str = Depends(error_ip_limit), s: AsyncSession = Depends(get_session)): |
163 | 162 | query = select(Codes).where(Codes.code == code) |
|
0 commit comments