forked from MeoProject/lx-music-api-server
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
362 lines (319 loc) · 14.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
#!/usr/bin/env python3
# ----------------------------------------
# - mode: python -
# - author: helloplhm-qwq -
# - name: main.py -
# - project: lx-music-api-server -
# - license: MIT -
# ----------------------------------------
# This file is part of the "lx-music-api-server" project.
import time
import aiohttp
import asyncio
import traceback
import threading
import ujson as json
from aiohttp.web import Response, FileResponse, StreamResponse, Application
from io import TextIOWrapper
import sys
import os
# Refuse to start on Python 2 or Python 3.0-3.5.
# NOTE(review): this module uses f-strings (a SyntaxError before Python 3.6), so on
# those interpreters compilation fails before this check can execute; the guard is
# effectively only informative for environments that somehow get this far.
if ((sys.version_info.major == 3 and sys.version_info.minor < 6) or sys.version_info.major == 2):
    print('Python版本过低,请使用Python 3.6+ ')
    sys.exit(1)
# fix: module not found: common/modules
# Put this script's own directory on sys.path so the sibling `common` and
# `modules` packages import regardless of the current working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from common import utils
from common import config, localMusic
from common import lxsecurity
from common import log
from common import Httpx
from common import variable
from common import scheduler
from common import lx_script
from common import gcsp
import modules
def handleResult(dic, status=200) -> Response:
    """Serialize ``dic`` into a pretty-printed JSON aiohttp ``Response``.

    A dict is sent as-is. Any other value is wrapped in the standard success
    envelope ``{'code': 0, 'msg': 'success', 'data': dic}``.

    :param dic: payload to serialize.
    :param status: HTTP status code for the response (default 200).
    :return: ``Response`` with ``application/json`` content type.
    """
    payload = dic if isinstance(dic, dict) else {'code': 0, 'msg': 'success', 'data': dic}
    body = json.dumps(payload, indent=2, ensure_ascii=False)
    return Response(body=body, content_type='application/json', status=status)
logger = log.log("main")
aiologger = log.log('aiohttp_web')
if (sys.version_info < (3, 8)):
    logger.warning('您使用的Python版本已经停止更新,不建议继续使用')
# Exception type that signals the server task was cancelled (used by initMain).
# asyncio.CancelledError is the public name on every supported version: on 3.6/3.7
# it is an alias of concurrent.futures.CancelledError, and on 3.8+ it is
# asyncio.exceptions.CancelledError. This avoids importing the private
# concurrent.futures._base module, which only resolved because asyncio had
# already loaded concurrent.futures as a side effect.
stopEvent = asyncio.CancelledError
def start_checkcn_thread() -> None:
    """Launch ``Httpx.checkcn`` on a newly created background thread and return immediately."""
    checker = threading.Thread(target=Httpx.checkcn)
    checker.start()
# check request info before start
def _coerce_response(resp):
    """Convert a handler's return value into an aiohttp response object.

    - ``str`` / ``list`` / ``dict``: JSON via ``handleResult`` (HTTP 200).
    - flask-like ``(body, status)`` tuple: JSON for str/list/dict bodies,
      otherwise ``text/plain`` of ``str(body)`` with the given status.
    - ``Response`` / ``FileResponse`` / ``StreamResponse``: returned unchanged.
    - anything else: ``text/plain`` of ``str(resp)`` with HTTP 200.
    """
    if (isinstance(resp, (str, list, dict))):
        return handleResult(resp)
    if (isinstance(resp, tuple) and len(resp) == 2):  # flask like response
        body, status = resp
        if (isinstance(body, (str, list, dict))):
            return handleResult(body, status)
        return Response(body=str(body), content_type='text/plain', status=status)
    if (not isinstance(resp, (Response, FileResponse, StreamResponse))):
        return Response(body=str(resp), content_type='text/plain', status=200)
    return resp


async def handle_before_request(app, handler):
    """Old-style aiohttp middleware factory that guards every request.

    Before calling ``handler`` it:
      1. resolves ``request.remote_addr``: when ``common.reverse_proxy.allow_proxy``
         is set and the configured real-IP header is present, that header is used,
         but only if public-IP forwarding is allowed or the direct peer is local
         (otherwise 403);
      2. rejects banned IPs (403);
      3. applies the global and per-IP rate limits (429), then records this
         request's time;
      4. when ``security.allowed_host.enable`` is set, answers 404 for hosts not
         in the allow-list, optionally banning the caller.
    The handler's return value is normalized by ``_coerce_response`` and an
    access line is logged.

    Any exception is logged and answered with a JSON 500 response.
    """
    async def handle_request(request):
        try:
            if (config.read_config('common.reverse_proxy.allow_proxy')):
                if (request.headers.get(config.read_config('common.reverse_proxy.real_ip_header'))):
                    # proxy header
                    if (config.read_config('common.reverse_proxy.allow_public_ip') or utils.is_local_ip(request.remote)):
                        request.remote_addr = request.headers.get(
                            config.read_config('common.reverse_proxy.real_ip_header'))
                    else:
                        return handleResult({"code": 1, "msg": "不允许的公网ip转发", "data": None}, 403)
                else:
                    request.remote_addr = request.remote
            else:
                request.remote_addr = request.remote
            # check ip
            if (config.check_ip_banned(request.remote_addr)):
                return handleResult({"code": 1, "msg": "您的IP已被封禁", "data": None}, 403)
            # check global rate limit
            if ((time.time() - config.getRequestTime('global'))
                    < (config.read_config("security.rate_limit.global"))):
                return handleResult({"code": 5, "msg": "全局限速", "data": None}, 429)
            # check per-ip rate limit
            if ((time.time() - config.getRequestTime(request.remote_addr))
                    < (config.read_config("security.rate_limit.ip"))):
                return handleResult({"code": 5, "msg": "IP限速", "data": None}, 429)
            # update request time
            config.updateRequestTime('global')
            config.updateRequestTime(request.remote_addr)
            # check host
            if (config.read_config("security.allowed_host.enable")):
                if request.host.split(":")[0] not in config.read_config("security.allowed_host.list"):
                    if config.read_config("security.allowed_host.blacklist.enable"):
                        config.ban_ip(request.remote_addr, int(
                            config.read_config("security.allowed_host.blacklist.length")))
                    return handleResult({'code': 6, 'msg': '未找到您所请求的资源', 'data': None}, 404)
            resp = _coerce_response(await handler(request))
            aiologger.info(
                f'{request.remote_addr + ("" if (request.remote == request.remote_addr) else f"|proxy@{request.remote}")} - {request.method} "{request.path}", {resp.status}')
            return resp
        except Exception:
            # `except Exception` (not bare) so task cancellation is not swallowed.
            logger.error(traceback.format_exc())
            # A middleware must return a response object; the original returned a
            # plain dict here, which aiohttp rejects instead of sending this error.
            return handleResult({"code": 4, "msg": "内部服务器错误", "data": None}, 500)
    return handle_request
async def main(request):
    """Root endpoint (``/``): answer with the bare success envelope."""
    ok_body = {"code": 0, "msg": "success", "data": None}
    return handleResult(ok_body)
async def handle(request):
    """API endpoint for ``/{method}/{source}/{songId}[/{quality}]``.

    Unless the request host is listed in ``security.whitelist_host``:
      - with ``security.key.enable``, the ``X-Request-Key`` header must be one of
        ``security.key.values`` (otherwise 403, optionally banning the caller);
      - with ``security.check_lxm.enable``, the ``lxm`` header must pass
        ``lxsecurity.checklxmheader`` (otherwise 403, optionally banning).

    The request is then dispatched to the public coroutine ``modules.<method>``
    when it exists, else to ``modules.other``. Handler exceptions are logged and
    answered with a JSON 500. ``quality`` is None on the three-segment route.
    """
    method = request.match_info.get('method')
    source = request.match_info.get('source')
    songId = request.match_info.get('songId')
    quality = request.match_info.get('quality')
    host = request.host.split(':')[0]
    if (config.read_config("security.key.enable") and host not in config.read_config('security.whitelist_host')):
        if (request.headers.get("X-Request-Key")) not in config.read_config("security.key.values"):
            if (config.read_config("security.key.ban")):
                config.ban_ip(request.remote_addr)
            return handleResult({"code": 1, "msg": "key验证失败", "data": None}, 403)
    if (config.read_config('security.check_lxm.enable') and host not in config.read_config('security.whitelist_host')):
        lxm = request.headers.get('lxm')
        if (not lxsecurity.checklxmheader(lxm, request.url)):
            if (config.read_config('security.lxm_ban.enable')):
                config.ban_ip(request.remote_addr)
            return handleResult({"code": 1, "msg": "lxm请求头验证失败", "data": None}, 403)
    try:
        query = dict(request.query)
        # `method` comes straight from the URL. Never look up private or dunder
        # attributes (e.g. "__class__") on the modules package; only public names
        # are dispatched directly, and everything else goes through modules.other.
        if (not method.startswith('_') and method in dir(modules)):
            return handleResult(await getattr(modules, method)(source, songId, quality, query))
        return handleResult(await modules.other(method, source, songId, quality, query))
    except Exception:
        logger.error(traceback.format_exc())
        return handleResult({'code': 4, 'msg': '内部服务器错误', 'data': None}, 500)
async def handle_404(request):
    """Catch-all route handler: JSON "resource not found" body with HTTP 404."""
    not_found = {'code': 6, 'msg': '未找到您所请求的资源', 'data': None}
    return handleResult(not_found, 404)
async def handle_local(request):
    """Serve local-music resources for ``/local/{type}``.

    The ``q`` query parameter is URL-safe base64 ('-'/'_' in place of '+'/'/')
    that is decoded with ``utils.createBase64Decode`` and parsed as JSON; its
    ``'p'`` value is looked up in ``localMusic.map``. ``{type}`` selects:
      - ``u``: audio file, ``l``: lyric, ``p``: cover (404 when ``p`` is unknown);
      - ``c``: availability check, returning ``localMusic.checkLocalMusic(p)``
        or all-False flags when ``p`` is unknown.
    A malformed ``q`` (or one without ``'p'``) gives "请求参数有错" with 404;
    an unrecognised ``{type}`` gives "not found" with 404.
    """
    not_found = {'code': 6, 'msg': '未找到您所请求的资源', 'data': None}
    try:
        query = dict(request.query)
        data = query.get('q')
        data = utils.createBase64Decode(
            data.replace('-', '+').replace('_', '/'))
        data = json.loads(data)
        t = request.match_info.get('type')
        data['t'] = t
        path = data['p']  # previously a missing 'p' raised outside this try
    except Exception:
        logger.info(traceback.format_exc())
        return handleResult({'code': 6, 'msg': '请求参数有错', 'data': None}, 404)
    try:
        # O(1) lookup instead of materialising list(map.keys()) per request.
        known = path in localMusic.map
    except TypeError:
        # An unhashable JSON value (list/dict) can never be a key; the original
        # list-based membership test simply returned False for it.
        known = False
    if (t == 'c'):
        if (not known):
            return {
                'code': 0,
                'msg': 'success',
                'data': {
                    'file': False,
                    'cover': False,
                    'lyric': False
                }
            }
        return {
            'code': 0,
            'msg': 'success',
            'data': localMusic.checkLocalMusic(path)
        }
    if (t not in ('u', 'l', 'p') or not known):
        # Unknown types used to fall through and return None (served as "None", 200).
        return handleResult(not_found, 404)
    if (t == 'u'):
        return await localMusic.generateAudioFileResonse(path)
    if (t == 'l'):
        return await localMusic.generateAudioLyricResponse(path)
    return await localMusic.generateAudioCoverResonse(path)
# Build the aiohttp application; every request first passes through the
# handle_before_request middleware (proxy IP resolution, IP ban, rate limits,
# host allow-list).
app = Application(middlewares=[handle_before_request])
# Register the app under the name "app" via utils.setGlobal
# (presumably so other project modules can reach it — confirm in common.utils).
utils.setGlobal(app, "app")
# mainpage
app.router.add_get('/', main)
# api: both the 4-segment and 3-segment forms are served by `handle`;
# the 3-segment form has no {quality}, so handle() sees quality=None.
app.router.add_get('/{method}/{source}/{songId}/{quality}', handle)
app.router.add_get('/{method}/{source}/{songId}', handle)
# local music resources, dispatched on {type} inside handle_local
app.router.add_get('/local/{type}', handle_local)
# optional: serve the generated client script
if (config.read_config('common.allow_download_script')):
    app.router.add_get('/script', lx_script.generate_script_response)
# optional: gcsp handler mounted at a configurable path for any HTTP method
if (config.read_config('module.gcsp.enable')):
    app.router.add_route('*', config.read_config('module.gcsp.path'), gcsp.handle_request)
# 404
# Catch-all for every remaining path and method. It is registered last because
# aiohttp tries routes in registration order, so it must not shadow the routes above.
app.router.add_route('*', '/{tail:.*}', handle_404)
def _listen_key(host, port):
    """Return the identifier recorded in ``variable.running_ports`` for a listener."""
    return f'{host}_{port}'


def _format_listen_url(scheme, host, port):
    """Return ``scheme://host:port`` for logging, bracketing hosts that contain ':' (IPv6)."""
    shown_host = host if (':' not in host) else '[' + host + ']'
    return f'{scheme}://{shown_host}:{port}'


def _is_address_in_use(err):
    """Return True when the OSError ``err`` reports an address/port already in use."""
    import errno
    # errno.EADDRINUSE is 98 on Linux, 48 on macOS and 10048 on Windows.
    if (err.errno in (errno.EADDRINUSE, 98, 10048)):
        return True
    # Fallback on the message prefixes the original check relied on.
    return str(err).startswith("[Errno 98]") or str(err).startswith('[Errno 10048]')


async def run_app_host(host):
    """Start HTTP and (optionally) HTTPS listeners for ``host``.

    On each attempt, ``common.ports`` is read and every port not yet recorded in
    ``variable.running_ports`` (as ``f'{host}_{port}'``) is bound:
      - ports not listed in ``common.ssl_info.ssl_ports`` are served over plain HTTP;
      - ports listed there are served over HTTPS, but only when
        ``common.ssl_info.enable`` is set and both the configured cert and private
        key files exist.
    Because started ports are recorded, a retry binds only what is still missing.
    On "address already in use" it waits 10s and retries, giving up once
    ``retries`` exceeds 4. Any other OSError (including ssl.SSLError from loading
    the certificate) is logged and aborts this host instead of spinning forever.
    """
    retries = 0
    while True:
        if (retries > 4):
            logger.warning("重试次数已达上限,但仍有部分端口未能完成监听,已自动进行忽略")
            break
        try:
            ports = [int(port) for port in config.read_config('common.ports')]
            ssl_ports = [int(port) for port in config.read_config(
                'common.ssl_info.ssl_ports')]
            # running_ports stores f'{host}_{port}' strings, so membership must be
            # tested with the same key (the original compared bare ints, which
            # never matched and re-bound already-running SSL ports on retry).
            pending = [p for p in ports
                       if (_listen_key(host, p) not in variable.running_ports)]
            final_ports = [p for p in pending if (p not in ssl_ports)]
            final_ssl_ports = [p for p in pending if (p in ssl_ports)]
            # 读取证书和私钥路径 (certificate and private-key paths)
            cert_path = config.read_config('common.ssl_info.path.cert')
            privkey_path = config.read_config('common.ssl_info.path.privkey')
            # 创建 HTTP AppRunner (HTTP runner)
            http_runner = aiohttp.web.AppRunner(app)
            await http_runner.setup()
            # 启动 HTTP 端口监听 (start HTTP listeners)
            for port in final_ports:
                http_site = aiohttp.web.TCPSite(http_runner, host, port)
                await http_site.start()
                variable.running_ports.append(_listen_key(host, port))
                logger.info(f"监听 -> {_format_listen_url('http', host, port)}")
            if (config.read_config("common.ssl_info.enable") and final_ssl_ports):
                if (os.path.exists(cert_path) and os.path.exists(privkey_path)):
                    import ssl
                    # 创建 SSL 上下文 (server-side SSL context from the configured cert/key)
                    ssl_context = ssl.create_default_context(
                        ssl.Purpose.CLIENT_AUTH)
                    ssl_context.load_cert_chain(cert_path, privkey_path)
                    # 创建 HTTPS AppRunner (HTTPS runner)
                    https_runner = aiohttp.web.AppRunner(app)
                    await https_runner.setup()
                    # 启动 HTTPS 端口监听: only the SSL ports still pending
                    # (the original iterated every ssl_port, re-binding on retry).
                    for port in final_ssl_ports:
                        https_site = aiohttp.web.TCPSite(
                            https_runner, host, port, ssl_context=ssl_context)
                        await https_site.start()
                        variable.running_ports.append(_listen_key(host, port))
                        logger.info(f"监听 -> {_format_listen_url('https', host, port)}")
                else:
                    logger.error("证书或私钥文件不存在,已跳过HTTPS端口监听")
            logger.debug(f"HOST({host}) 已完成监听")
            break
        except OSError as e:
            if (_is_address_in_use(e)):
                logger.error("端口已被占用,请检查\n" + str(e))
                logger.info('服务器将在10s后再次尝试启动...')
                await asyncio.sleep(10)
                logger.info('重新尝试启动...')
                retries += 1
            else:
                logger.error("未知错误,请检查\n" + traceback.format_exc())
                # Previously this fell through to an immediate retry with no delay
                # and no retry count, looping forever on persistent errors.
                break
async def run_app():
    """Bring up the listeners of each host in ``common.hosts``, one host after another."""
    hosts = config.read_config('common.hosts')
    for current_host in hosts:
        await run_app_host(current_host)
async def initMain():
    """Top-level coroutine: initialise services, start listeners and wait forever.

    It awaits ``scheduler.run()``, opens the shared ``aiohttp.ClientSession``
    stored in ``variable.aioSession``, calls ``localMusic.initMain()``, and then
    starts the listeners via ``run_app()``. It then blocks on an Event that is
    never set, until cancellation or KeyboardInterrupt. On exit, whether normal
    or via an error, the session is closed and ``variable.running`` is set to
    False. Exceptions raised before the listeners start propagate to the caller.
    """
    await scheduler.run()
    variable.aioSession = aiohttp.ClientSession(trust_env=True)
    try:
        localMusic.initMain()
    except BaseException:
        # Don't leak the freshly opened session if local-music setup fails;
        # the exception still propagates to the caller as before.
        await variable.aioSession.close()
        raise
    try:
        await run_app()
        logger.info("服务器启动成功,请按下Ctrl + C停止")
        await asyncio.Event().wait()  # 等待停止事件 (never set: wait until cancelled)
    except (KeyboardInterrupt, stopEvent):
        pass
    except BaseException:
        # Single catch-all, equivalent to the original's identical `except OSError`
        # and bare `except:` branches: log it and fall through to shutdown.
        logger.error("遇到未知错误,请查看日志")
        logger.error(traceback.format_exc())
    finally:
        logger.info('waiting for sessions to complete...')
        if variable.aioSession:
            await variable.aioSession.close()
        variable.running = False
        logger.info("Server stopped")
if __name__ == "__main__":
    try:
        start_checkcn_thread()
        asyncio.run(initMain())
    except KeyboardInterrupt:
        pass
    except BaseException:
        # Same reach as a bare `except:`: any other failure is logged and a
        # diagnostic dump is written to the current working directory.
        logger.critical('初始化出错,请检查日志')
        logger.critical(traceback.format_exc())
        # Compute the file name once so the file written and the name logged agree
        # (the original called time.time() twice and could report a different name).
        dump_name = 'dumprecord_{}.txt'.format(int(time.time()))
        with open(dump_name, 'w', encoding='utf-8') as dump_file:
            dump_file.write(traceback.format_exc())
            public_vars = ''.join(
                k + ' = ' + str(getattr(variable, k)) + '\n'
                for k in dir(variable) if (not k.startswith('_')))
            dump_file.write('\n\nGlobal variable object:\n\n' + public_vars)
            # Iterate a snapshot: if anything imports while the modules are being
            # formatted, sys.modules would change size mid-iteration.
            loaded_modules = ''.join(
                name + ' = ' + str(module) + '\n'
                for name, module in list(sys.modules.items()) if (not name.startswith('_')))
            dump_file.write('\n\nsys.modules:\n\n' + loaded_modules)
        logger.critical('{} 已保存至当前目录'.format(dump_name))
    finally:
        # Close any text-mode log files registered in variable.log_files.
        for log_file in variable.log_files:
            if (log_file and isinstance(log_file, TextIOWrapper)):
                log_file.close()