# main.py (9.3 KB)
  1. import asyncio
  2. import logging
  3. import openai
  4. from fastapi import FastAPI, Request
  5. from fastapi.encoders import jsonable_encoder
  6. from fastapi.responses import JSONResponse
  7. from openai import OpenAI
  8. from pydantic import BaseModel
  9. from pymysql import OperationalError
  10. from starlette.middleware.cors import CORSMiddleware
  11. from LocalModel import CustomLogin, SaveUser, QueryUser, DeleteUser
  12. from logic import *
  13. from model import CustomUser, UserInfo, ZaneTest, database
  14. API_KEY = "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"
  15. app = FastAPI()
  16. app.add_middleware(
  17. CORSMiddleware,
  18. allow_origins=["*"],
  19. allow_credentials=True,
  20. allow_methods=["*"],
  21. allow_headers=["*"],
  22. )
  23. logging.basicConfig(
  24. level=logging.INFO, # 设置日志级别
  25. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # 日志格式
  26. datefmt='%Y-%m-%d %H:%M:%S', # 时间格式
  27. filename='app.log', # 日志文件存储位置
  28. filemode='a' # 文件模式,'a'为追加模式,默认为'a',还可以选择'w'覆写模式
  29. )
  30. def check_db_connect():
  31. try:
  32. database.connect(reuse_if_open=True)
  33. dt = CustomUser.select()
  34. for item in dt:
  35. print(item.id)
  36. except OperationalError as e:
  37. if 'MySQL server has gone away' in str(e):
  38. database.close()
  39. database.connect()
  40. logging.info("reconnect database")
  41. # threading.Timer(60 * 60, check_db_connect).start()
  42. # check_db_connect()
  43. class Question(BaseModel):
  44. user: str
  45. content: str
  46. stream: bool = True
  47. async def ai_stream(content: str):
  48. client = openai.OpenAI(api_key=API_KEY)
  49. completion = client.chat.completions.create(
  50. model="gpt-3.5-turbo",
  51. stream=True,
  52. messages=[
  53. {"role": "system", "content": content}
  54. ]
  55. )
  56. try:
  57. for chunk in completion:
  58. if chunk.choices[0].delta.content:
  59. yield chunk.choices[0].delta.content + "\n"
  60. await asyncio.sleep(0.01) # 稍微暂停以允许其他任务执行
  61. except Exception as e:
  62. yield f"Error: {e}\n"
  63. def ai_normal(content: str):
  64. client = OpenAI(api_key=API_KEY)
  65. completion = client.chat.completions.create(
  66. model="gpt-3.5-turbo",
  67. messages=[
  68. {"role": "system",
  69. "content": content},
  70. ]
  71. )
  72. return {"msg": completion.choices[0].message.content}
  73. # @app.post("/ai/")
  74. # async def do_ai(question: Question):
  75. # if question.stream:
  76. # return StreamingResponse(ai_stream(question.content), media_type="text/event-stream")
  77. # else:
  78. # return ai_normal(question.content)
  79. # def get_value(s: str):
  80. # return s
  81. # class MyRequest(BaseModel):
  82. # content: str
  83. # def test_func(arg1: str):
  84. # print(arg1)
  85. # return "nice"
  86. # @app.post("/func/")
  87. # async def call_func(mq: MyRequest):
  88. # client = OpenAI(api_key=API_KEY)
  89. # messages = []
  90. # messages.append({"role": "system",
  91. # "content": "You are a helpful assistant"})
  92. # messages.append({"role": "system",
  93. # "content": "If you need to call a function but you do not have enough parameters, ask the user to provide you with the missing parameters."})
  94. # messages.append({"role": "system",
  95. # "content": "You must answer in Chinese"})
  96. # messages.append({"role": "user", "content": mq.content})
  97. #
  98. # tools = [{
  99. # "type": "function",
  100. # "function": {
  101. # "name": "get_user_birthday",
  102. # "description": "get user birthday",
  103. # "parameters": {
  104. # "type": "object",
  105. # "properties": {
  106. # "birthday": {
  107. # "type": "string",
  108. # "description": "user birthday"
  109. # },
  110. # "city": {
  111. # "type": "string",
  112. # "description": "city of user born"
  113. # }
  114. # },
  115. # "required": ["birthday", "city"],
  116. # }
  117. # }
  118. # }]
  119. #
  120. # completion1 = client.chat.completions.create(
  121. # model="gpt-4",
  122. # messages=messages,
  123. # tools=tools
  124. # )
  125. # ast1 = completion1.choices[0].message
  126. # return {"msg": ast1}
  127. #
  128. # class YearInfo(BaseModel):
  129. # year: int
  130. # month: int
  131. # day: int
  132. # hour: int
  133. # minute: int
  134. #
  135. # @app.post("/wnl/add/")
  136. # async def add_wnl(info: YearInfo):
  137. # # result = []
  138. # # for year in range(info.from_year, info.to_year+1):
  139. # # for month in range(1, 13):
  140. # # max_day = 30
  141. # # if month == 2:
  142. # # if year % 4 == 0:
  143. # # max_day = 29
  144. # # else:
  145. # # max_day = 28
  146. # # elif month in (1, 3, 5, 7, 8, 10, 12):
  147. # # max_day = 31
  148. # # for day in range(1, max_day + 1):
  149. # # result.append({
  150. # # "nian": year,
  151. # # "yue": month,
  152. # # "ri": day
  153. # # })
  154. # # Wannianli.insert_many(result).execute()
  155. # wnl = Wannianli.select()
  156. # ct = len(wnl)
  157. # return {"data": "新增了" + str(ct) + "条数据"}
  158. #
  159. #
  160. # @app.post("/wnl/update/")
  161. # async def update_wnl(info: YearInfo):
  162. # data = get_wannianli_data(info.year, info.month, info.day)
  163. # msg = []
  164. # if data is not None:
  165. # msg = [data.nian_gan, data.nian_zhi,
  166. # data.yue_gan, data.yue_zhi,
  167. # data.ri_gan, data.ri_zhi]
  168. # hour_data = get_hour_of_day(data.ri_gan, info.hour)
  169. # msg.append(hour_data[0])
  170. # msg.append(hour_data[1])
  171. # return {"date": str(info.year) + "-" + str(info.month) + "-" + str(info.day) + " " + str(info.hour) + ":" + str(
  172. # info.minute),
  173. # "msg": msg}
  174. @app.post("/api/getSiZhuInfo")
  175. async def getSiZhuInfo(request: SiZhuInfoRequest):
  176. startDtm = None
  177. if request.mode == 2:
  178. startDtm = calc_date_of_sizhu(request)
  179. bazi = BaZi(request)
  180. dc = DataCenter(bazi)
  181. if startDtm is not None:
  182. bazi.taiyangshi = startDtm.__str__()
  183. fill_sizhu_in_bazi(bazi, dc)
  184. # logging.info("this is a info")
  185. # logging.info(jsonable_encoder(bazi))
  186. # print(jsonable_encoder(bazi))
  187. return jsonable_encoder(bazi)
  188. @app.post("/api/customLogin")
  189. async def customLogin(request: CustomLogin):
  190. logging.info("login")
  191. dt = CustomUser.select().where(CustomUser.user == request.user,
  192. CustomUser.psd == request.psd).first()
  193. if dt is not None:
  194. return {"msg": "ok", "name": dt.name, "sexy": dt.sexy}
  195. else:
  196. return {"msg": "error", "name": None, "sexy": None}
  197. @app.post("/api/saveUser")
  198. async def saveUser(request: SaveUser):
  199. ct = UserInfo.select().where(UserInfo.customer == request.customer).count()
  200. if ct >= 100:
  201. return {"msg": "超过可以保存的用户上限,请联系管理员", "state": -1}
  202. UserInfo.insert(request.to_db_data()).execute()
  203. return {"msg": "保存用户信息成功", "state": 200}
  204. def __build_user_object(dt: UserInfo):
  205. return {
  206. "id": dt.id,
  207. "name": dt.name,
  208. "beizhu": dt.beizhu,
  209. "isMan": bool(dt.man),
  210. "leibie": dt.leibie,
  211. "year": dt.year,
  212. "month": dt.month,
  213. "day": dt.day,
  214. "hour": dt.hour,
  215. "minute": dt.minute,
  216. "sheng": dt.sheng,
  217. "shi": dt.shi,
  218. "qu": dt.qu,
  219. "niangan": dt.niangan,
  220. "nianzhi": dt.nianzhi,
  221. "yuegan": dt.yuegan,
  222. "yuezhi": dt.yuezhi,
  223. "rigan": dt.rigan,
  224. "rizhi": dt.rizhi,
  225. "shigan": dt.shigan,
  226. "shizhi": dt.shizhi,
  227. "customer": dt.customer,
  228. "joinTime": dt.join_time
  229. }
  230. def __do_query_user(customer: str, filter: str):
  231. dts = UserInfo.select().where(UserInfo.customer == customer, UserInfo.enabled == 1)
  232. data = []
  233. if len(dts) > 0:
  234. for dt in dts:
  235. if filter is None:
  236. data.append(__build_user_object(dt))
  237. else:
  238. if filter in dt.name:
  239. data.append(__build_user_object(dt))
  240. return data
  241. @app.post("/api/queryUser")
  242. async def queryUser(request: QueryUser):
  243. data = __do_query_user(request.customer, request.filter)
  244. return jsonable_encoder(data)
  245. @app.post("/api/deleteUser")
  246. async def deleteUser(request: DeleteUser):
  247. UserInfo.update({"enabled": 0}).where(UserInfo.id == request.id).execute()
  248. return __do_query_user(request.customer, None)
  249. @app.post("/api/test")
  250. async def test(request: Request):
  251. request_origin = request.headers.get('origin')
  252. if request_origin is None:
  253. request_origin = "unknown"
  254. content = {"message": "Hello World" +
  255. request_origin, "db": "disconnect!!!"}
  256. headers = {'Access-Control-Allow-Origin': request_origin}
  257. content["db"] = "is_closed: " + \
  258. str(database.is_closed()) + " is_usable:" + \
  259. str(database.is_connection_usable())
  260. try:
  261. dt = CustomUser.select()
  262. for item in dt:
  263. print(item.name)
  264. except Exception as e:
  265. logging.info(e)
  266. return JSONResponse(content=content, headers=headers)
  267. @app.get("/api/test2")
  268. async def test2():
  269. users = ZaneTest.select().where(ZaneTest.a != None)
  270. users = [__build_user_object(dt) for dt in users]
  271. return {"message": "Hello World", "users": users}