forked from vastsa/FileCodeBox
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
149 lines (127 loc) · 4.68 KB
/
main.py
File metadata and controls
149 lines (127 loc) · 4.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import datetime
import os
import random
import secrets
import uuid

from fastapi import FastAPI, Depends, UploadFile, Form, File
from sqlalchemy.orm import Session
from starlette.requests import Request
from starlette.responses import HTMLResponse
from starlette.staticfiles import StaticFiles

import database
from database import engine, SessionLocal, Base
# Create any tables declared on Base that do not exist yet.
Base.metadata.create_all(bind=engine)

app = FastAPI()

# Make sure the directory mounted below exists; exist_ok avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs('./static', exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Page templates are read once at import time. The files are closed
# deterministically and decoded with an explicit encoding instead of the
# platform default.
# NOTE(review): assumes the templates are stored as UTF-8 - confirm.
with open('templates/index.html', 'r', encoding='utf-8') as _template:
    index_html = _template.read()
with open('templates/admin.html', 'r', encoding='utf-8') as _template:
    admin_html = _template.read()

# Expiry time of a share, in hours
exp_hour = 24
# Number of wrong pickup codes allowed per client
error_count = 5
# Ban duration, in minutes
error_minute = 60
# URL path of the admin page
admin_address = 'admin'
# Admin password
# NOTE(review): hard-coded default credential - change before deploying.
admin_password = 'admin'
# Failure bookkeeping, keyed by client host: {'count': int, 'time': datetime}
error_ip_count = {}
def get_db():
    """Yield a database session and close it when the consumer is finished.

    Written as a generator so it can be plugged in via ``Depends(get_db)``;
    the ``finally`` clause closes the session even if the consumer raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
def get_code(db: Session = Depends(get_db)):
    """Return a five-digit pickup code that is not present in the database.

    Args:
        db: Session used to check whether a candidate code is already taken.

    Returns:
        The code as a string in the range ``'10000'``..``'99999'``.
    """
    while True:
        # The pickup code is the only secret protecting a share, so it is
        # drawn from the OS CSPRNG rather than the predictable `random`.
        code = str(secrets.randbelow(90000) + 10000)
        # Compare as a string, matching how codes are stored and how the
        # other queries in this file look them up.
        taken = db.query(database.Codes).filter(database.Codes.code == code).first()
        if not taken:
            return code
def get_file_name(key, ext, file):
    """Write an uploaded file into a dated directory under ./static/upload.

    Args:
        key: Base name for the stored file.
        ext: Requested extension. Only alphanumerics, ``_`` and ``-`` are
            kept, so a client-supplied value cannot introduce path
            separators into the stored name.
        file: Upload object exposing a binary file-like ``.file`` attribute.

    Returns:
        A ``(key, size, path)`` tuple where ``size`` is the number of bytes
        written and ``path`` is the stored location without the leading
        ``.`` (i.e. starting with ``/static/upload/``).
    """
    now = datetime.datetime.now()
    path = f'./static/upload/{now.year}/{now.month}/{now.day}/'
    safe_ext = ''.join(ch for ch in str(ext) if ch.isalnum() or ch in '_-')
    name = f'{key}.{safe_ext}' if safe_ext else str(key)
    # exist_ok avoids failing when a concurrent upload creates the directory.
    os.makedirs(path, exist_ok=True)
    size = 0
    chunk_size = 1024 * 1024
    with open(os.path.join(path, name), 'wb') as out:
        # Copy in chunks so a large upload is never held fully in memory.
        for chunk in iter(lambda: file.file.read(chunk_size), b''):
            out.write(chunk)
            size += len(chunk)
    return key, size, path[1:] + name
@app.get(f'/{admin_address}')
async def admin():
    """Serve the admin page markup loaded at import time."""
    page = HTMLResponse(admin_html)
    return page
@app.post(f'/{admin_address}')
async def admin_post(request: Request, db: Session = Depends(get_db)):
    """List every stored share for an authenticated admin.

    The password is read from the ``pwd`` request header.

    Returns:
        ``{'code': 200, ...}`` with all ``Codes`` rows under ``data`` when
        the password matches, otherwise ``{'code': 400, ...}``.
    """
    supplied = request.headers.get('pwd')
    # Constant-time comparison so response timing does not leak how much of
    # the password matched; bytes are used because compare_digest rejects
    # non-ASCII str. A missing header never authenticates.
    authorised = supplied is not None and secrets.compare_digest(
        supplied.encode('utf-8'), admin_password.encode('utf-8')
    )
    if not authorised:
        return {'code': 400, 'msg': '密码错误'}
    codes = db.query(database.Codes).all()
    return {'code': 200, 'msg': '查询成功', 'data': codes}
@app.delete(f'/{admin_address}')
async def admin_delete(request: Request, code: str, db: Session = Depends(get_db)):
    """Delete the share identified by ``code`` for an authenticated admin.

    The password is read from the ``pwd`` request header. For shares that
    are not plain text, the uploaded file on disk is removed as well.

    Returns:
        ``{'code': 200, ...}`` on success, ``{'code': 404, ...}`` when no
        share has that code, ``{'code': 400, ...}`` on a wrong password.
    """
    supplied = request.headers.get('pwd')
    # Constant-time comparison; a missing header never authenticates.
    authorised = supplied is not None and secrets.compare_digest(
        supplied.encode('utf-8'), admin_password.encode('utf-8')
    )
    if not authorised:
        return {'code': 400, 'msg': '密码错误'}

    query = db.query(database.Codes).filter(database.Codes.code == code)
    record = query.first()
    if record is None:
        # Previously this path crashed with AttributeError on None.
        return {'code': 404, 'msg': '文件不存在'}

    if record.type != 'text/plain':
        try:
            os.remove('.' + record.text)
        except FileNotFoundError:
            # The file is already gone; the row should still be removed.
            pass
    query.delete()
    db.commit()
    return {'code': 200, 'msg': '删除成功'}
@app.get('/')
async def index():
    """Serve the landing page markup loaded at import time."""
    page = HTMLResponse(index_html)
    return page
@app.post('/')
async def index(request: Request, code: str, db: Session = Depends(get_db)):
    """Look up a share by pickup code, rate-limiting wrong guesses per client.

    Failures are tracked in ``error_ip_count`` keyed by client host. Once a
    client has made ``error_count`` wrong guesses, further requests are
    rejected until ``error_minute`` minutes have passed since its most
    recent failure.

    Returns:
        ``{'code': 200, ..., 'data': <Codes row>}`` when the code exists,
        otherwise a ``{'code': 404, ...}`` payload (wrong code or banned).
    """
    host = request.client.host
    now = datetime.datetime.now()
    error = error_ip_count.get(host, {'count': 0, 'time': now})

    if error['count'] >= error_count:
        if now - error['time'] < datetime.timedelta(minutes=error_minute):
            return {'code': 404, 'msg': '请求过于频繁,请稍后再试'}
        # Ban has elapsed: reset and fall through to a normal lookup
        # (previously this path returned nothing at all).
        error['count'] = 0

    # Query only after the ban check so banned clients cost no DB work.
    info = db.query(database.Codes).filter(database.Codes.code == code).first()
    if not info:
        error['count'] += 1
        # Refresh the timestamp so the ban window starts at the last failure.
        error['time'] = now
        error_ip_count[host] = error
        return {'code': 404, 'msg': f'取件码错误,错误{error_count}次将被禁止{error_minute}分钟'}
    return {'code': 200, 'msg': '取件成功,请点击“取”查看', 'data': info}
@app.post('/share')
async def share(text: str = Form(default=None), file: UploadFile = File(default=None), db: Session = Depends(get_db)):
    """Store a text snippet or an uploaded file and return its pickup code.

    Before storing, shares whose ``use_time`` is older than ``exp_hour``
    hours are purged, including their uploaded files on disk.

    Args:
        text: Text to share; takes precedence over ``file`` when non-empty.
        file: Uploaded file to share.
        db: Database session.

    Returns:
        ``{'code': 200, ...}`` with the pickup code, name and text/path
        under ``data``, or ``{'code': 422, ...}`` when neither ``text`` nor
        ``file`` was supplied.
    """
    cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=exp_hour)
    expired = db.query(database.Codes).filter(database.Codes.use_time < cutoff_time)
    # Remove the files of expired shares too; deleting only the rows would
    # leave the uploads on disk forever.
    for stale in expired.all():
        if stale.type != 'text/plain':
            try:
                os.remove('.' + stale.text)
            except FileNotFoundError:
                # Already removed; nothing left to clean up for this row.
                pass
    expired.delete()
    db.commit()

    code = get_code(db)
    if text:
        info = database.Codes(
            code=code,
            text=text,
            type='text/plain',
            key=uuid.uuid4().hex,
            size=len(text),
            used=True,
            name='分享文本',
        )
        payload = {'code': code, 'name': '分享文本', 'text': text}
    elif file:
        # The extension is whatever follows the last dot of the client name.
        extension = file.filename.split('.')[-1]
        key, size, full_path = get_file_name(uuid.uuid4().hex, extension, file)
        info = database.Codes(
            code=code,
            text=full_path,
            type=file.content_type,
            key=key,
            size=size,
            used=True,
            name=file.filename,
        )
        payload = {'code': code, 'name': file.filename, 'text': full_path}
    else:
        return {'code': 422, 'msg': '参数错误', 'data': []}

    db.add(info)
    db.commit()
    return {'code': 200, 'msg': '上传成功,请点击文件库查看', 'data': payload}