"""FastAPI application exposing the BaZi calculation, user management and
diagnostic endpoints, plus helpers for talking to the OpenAI chat API."""

import asyncio
import logging
import os
from typing import Optional

import openai
from fastapi import FastAPI, Request
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from openai import OpenAI
from pydantic import BaseModel
from pymysql import OperationalError
from starlette.middleware.cors import CORSMiddleware

from LocalModel import CustomLogin, SaveUser, QueryUser, DeleteUser
from logic import *
from model import CustomUser, UserInfo, ZaneTest, database

# SECURITY: this secret is committed to source control and must be treated as
# leaked. Set OPENAI_API_KEY in the environment, rotate the key, and then drop
# the literal fallback below. The fallback is kept only so existing
# deployments keep working unchanged.
API_KEY = os.environ.get(
    "OPENAI_API_KEY",
    "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",
)

# Upper bound on UserInfo rows a single customer may store (see saveUser).
MAX_SAVED_USERS_PER_CUSTOMER = 100

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive CORS policy — confirm this is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logging.basicConfig(
    level=logging.INFO,  # log level
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',  # record format
    datefmt='%Y-%m-%d %H:%M:%S',  # timestamp format
    filename='app.log',  # log file location
    filemode='a'  # 'a' appends (the default); 'w' would overwrite the file
)


def check_db_connect():
    """Probe the database and reconnect if the server dropped the connection.

    Opens (or reuses) the connection, iterates every CustomUser row printing
    its id, and on an OperationalError whose text contains
    'MySQL server has gone away' closes and reopens the connection.
    Any other OperationalError is logged and otherwise ignored.
    """
    try:
        database.connect(reuse_if_open=True)
        dt = CustomUser.select()
        for item in dt:
            print(item.id)
    except OperationalError as e:
        # NOTE(review): if `database` wraps driver errors in its own exception
        # classes, pymysql's OperationalError may never be raised here —
        # confirm this handler actually fires.
        if 'MySQL server has gone away' in str(e):
            database.close()
            database.connect()
            logging.info("reconnect database")
        else:
            # Previously swallowed silently; keep best-effort but leave a trace.
            logging.exception("database connectivity check failed")
    # threading.Timer(60 * 60, check_db_connect).start()


# check_db_connect()


class Question(BaseModel):
    """Request body for the (currently disabled) /ai/ endpoint."""

    user: str
    content: str
    stream: bool = True


async def ai_stream(content: str):
    """Stream a chat completion for *content*, yielding newline-terminated text.

    Async generator. Each non-empty delta from the streamed completion is
    yielded with a trailing "\\n". Any exception raised while creating or
    consuming the stream is reported to the consumer as a single
    "Error: <message>\\n" item instead of propagating.
    """
    # The async client is used so that waiting for the next chunk suspends
    # this coroutine instead of blocking the event loop.
    client = openai.AsyncOpenAI(api_key=API_KEY)
    try:
        completion = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            stream=True,
            messages=[
                {"role": "system", "content": content}
            ]
        )
        async for chunk in completion:
            if not chunk.choices:
                # Defensive: skip chunks that carry no choices at all.
                continue
            text = chunk.choices[0].delta.content
            if text:
                yield text + "\n"
    except Exception as e:
        logging.exception("ai_stream failed")
        yield f"Error: {e}\n"


def ai_normal(content: str):
    """Run a non-streaming chat completion and return {"msg": <reply text>}."""
    client = OpenAI(api_key=API_KEY)
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": content},
        ]
    )
    return {"msg": completion.choices[0].message.content}


# ---------------------------------------------------------------------------
# Disabled endpoints and helpers, retained verbatim as commented-out code.
# ---------------------------------------------------------------------------

# @app.post("/ai/")
# async def do_ai(question: Question):
#     if question.stream:
#         return StreamingResponse(ai_stream(question.content), media_type="text/event-stream")
#     else:
#         return ai_normal(question.content)


# def get_value(s: str):
#     return s


# class MyRequest(BaseModel):
#     content: str


# def test_func(arg1: str):
#     print(arg1)
#     return "nice"


# @app.post("/func/")
# async def call_func(mq: MyRequest):
#     client = OpenAI(api_key=API_KEY)
#     messages = []
#     messages.append({"role": "system",
#                      "content": "You are a helpful assistant"})
#     messages.append({"role": "system",
#                      "content": "If you need to call a function but you do not have enough parameters, ask the user to provide you with the missing parameters."})
#     messages.append({"role": "system",
#                      "content": "You must answer in Chinese"})
#     messages.append({"role": "user", "content": mq.content})
#
#     tools = [{
#         "type": "function",
#         "function": {
#             "name": "get_user_birthday",
#             "description": "get user birthday",
#             "parameters": {
#                 "type": "object",
#                 "properties": {
#                     "birthday": {
#                         "type": "string",
#                         "description": "user birthday"
#                     },
#                     "city": {
#                         "type": "string",
#                         "description": "city of user born"
#                     }
#                 },
#                 "required": ["birthday", "city"],
#             }
#         }
#     }]
#
#     completion1 = client.chat.completions.create(
#         model="gpt-4",
#         messages=messages,
#         tools=tools
#     )
#     ast1 = completion1.choices[0].message
#     return {"msg": ast1}
#
# class YearInfo(BaseModel):
#     year: int
#     month: int
#     day: int
#     hour: int
#     minute: int
#
# @app.post("/wnl/add/")
# async def add_wnl(info: YearInfo):
#     # result = []
#     # for year in range(info.from_year, info.to_year+1):
#     #     for month in range(1, 13):
#     #         max_day = 30
#     #         if month == 2:
#     #             if year % 4 == 0:
#     #                 max_day = 29
#     #             else:
#     #                 max_day = 28
#     #         elif month in (1, 3, 5, 7, 8, 10, 12):
#     #             max_day = 31
#     #         for day in range(1, max_day + 1):
#     #             result.append({
#     #                 "nian": year,
#     #                 "yue": month,
#     #                 "ri": day
#     #             })
#     # Wannianli.insert_many(result).execute()
#     wnl = Wannianli.select()
#     ct = len(wnl)
#     return {"data": "新增了" + str(ct) + "条数据"}
#
#
# @app.post("/wnl/update/")
# async def update_wnl(info: YearInfo):
#     data = get_wannianli_data(info.year, info.month, info.day)
#     msg = []
#     if data is not None:
#         msg = [data.nian_gan, data.nian_zhi,
#                data.yue_gan, data.yue_zhi,
#                data.ri_gan, data.ri_zhi]
#         hour_data = get_hour_of_day(data.ri_gan, info.hour)
#         msg.append(hour_data[0])
#         msg.append(hour_data[1])
#     return {"date": str(info.year) + "-" + str(info.month) + "-" + str(info.day) + " " + str(info.hour) + ":" + str(
#         info.minute),
#         "msg": msg}


@app.post("/api/getSiZhuInfo")
async def getSiZhuInfo(request: SiZhuInfoRequest):
    """Build a BaZi from *request*, fill it in, and return it JSON-encoded.

    When ``request.mode == 2`` the value of ``calc_date_of_sizhu(request)`` is
    computed before the BaZi is built and, if not None, stored on
    ``bazi.taiyangshi`` as a string.
    """
    startDtm = None
    if request.mode == 2:
        # Must run before BaZi(request), matching the original call order.
        startDtm = calc_date_of_sizhu(request)
    bazi = BaZi(request)
    dc = DataCenter(bazi)
    if startDtm is not None:
        bazi.taiyangshi = str(startDtm)
    fill_sizhu_in_bazi(bazi, dc)
    # logging.info("this is a info")
    # logging.info(jsonable_encoder(bazi))
    # print(jsonable_encoder(bazi))
    return jsonable_encoder(bazi)


@app.post("/api/customLogin")
async def customLogin(request: CustomLogin):
    """Look up a CustomUser by user name and password.

    Returns ``{"msg": "ok", "name": ..., "sexy": ...}`` for the first matching
    row, or ``{"msg": "error", "name": None, "sexy": None}`` when none matches.
    """
    logging.info("login")
    # NOTE(review): request.psd is compared directly against the stored
    # column, which looks like plaintext password storage — confirm.
    dt = CustomUser.select().where(
        CustomUser.user == request.user,
        CustomUser.psd == request.psd,
    ).first()
    if dt is None:
        return {"msg": "error", "name": None, "sexy": None}
    return {"msg": "ok", "name": dt.name, "sexy": dt.sexy}


@app.post("/api/saveUser")
async def saveUser(request: SaveUser):
    """Insert a UserInfo row for the request unless the customer hit the cap.

    Returns state -1 with an explanatory message when the customer already
    has MAX_SAVED_USERS_PER_CUSTOMER or more rows, otherwise inserts
    ``request.to_db_data()`` and returns state 200.
    """
    ct = UserInfo.select().where(UserInfo.customer == request.customer).count()
    if ct >= MAX_SAVED_USERS_PER_CUSTOMER:
        return {"msg": "超过可以保存的用户上限,请联系管理员", "state": -1}
    UserInfo.insert(request.to_db_data()).execute()
    return {"msg": "保存用户信息成功", "state": 200}


def __build_user_object(dt: UserInfo):
    """Convert a row object into the plain dict shape returned to clients.

    Column attributes are copied one-to-one, except that ``man`` is exposed
    as the boolean ``isMan`` and ``join_time`` as ``joinTime``.
    """
    return {
        "id": dt.id,
        "name": dt.name,
        "beizhu": dt.beizhu,
        "isMan": bool(dt.man),
        "leibie": dt.leibie,
        "year": dt.year,
        "month": dt.month,
        "day": dt.day,
        "hour": dt.hour,
        "minute": dt.minute,
        "sheng": dt.sheng,
        "shi": dt.shi,
        "qu": dt.qu,
        "niangan": dt.niangan,
        "nianzhi": dt.nianzhi,
        "yuegan": dt.yuegan,
        "yuezhi": dt.yuezhi,
        "rigan": dt.rigan,
        "rizhi": dt.rizhi,
        "shigan": dt.shigan,
        "shizhi": dt.shizhi,
        "customer": dt.customer,
        "joinTime": dt.join_time
    }


def __do_query_user(customer: str, filter: Optional[str]):
    """Return the enabled UserInfo rows of *customer* as a list of dicts.

    Args:
        customer: value matched against ``UserInfo.customer``.
        filter: when not None, only rows whose ``name`` contains this
            substring are kept. (The name shadows the builtin but is part of
            the existing signature, so it is left as is.)

    Returns:
        A list of dicts produced by ``__build_user_object``; empty when no
        row matches.
    """
    rows = UserInfo.select().where(UserInfo.customer == customer, UserInfo.enabled == 1)
    return [
        __build_user_object(row)
        for row in rows
        # `row.name or ""` keeps a row with a missing name from raising
        # TypeError on the substring test.
        if filter is None or filter in (row.name or "")
    ]


@app.post("/api/queryUser")
async def queryUser(request: QueryUser):
    """Return the customer's enabled users, optionally filtered by name."""
    data = __do_query_user(request.customer, request.filter)
    return jsonable_encoder(data)


@app.post("/api/deleteUser")
async def deleteUser(request: DeleteUser):
    """Soft-delete a user (set ``enabled`` to 0) and return the remaining list.

    The update targets the row whose id equals ``request.id``; the returned
    list is the unfiltered result of querying ``request.customer``'s users.
    """
    UserInfo.update({"enabled": 0}).where(UserInfo.id == request.id).execute()
    return __do_query_user(request.customer, None)


@app.post("/api/test")
async def test(request: Request):
    """Diagnostic endpoint reporting the caller's Origin and DB connection state.

    Echoes the request's ``origin`` header ("unknown" when absent) in both the
    message and the ``Access-Control-Allow-Origin`` response header, reports
    ``database.is_closed()`` / ``database.is_connection_usable()``, and then
    tries to iterate CustomUser rows, logging (not raising) any failure.
    """
    request_origin = request.headers.get('origin')
    if request_origin is None:
        request_origin = "unknown"
    content = {
        "message": "Hello World" + request_origin,
        "db": "is_closed: " + str(database.is_closed())
              + " is_usable:" + str(database.is_connection_usable()),
    }
    headers = {'Access-Control-Allow-Origin': request_origin}
    try:
        dt = CustomUser.select()
        for item in dt:
            print(item.name)
    except Exception:
        # Best-effort probe: record the traceback but still answer the request.
        logging.exception("CustomUser probe query failed")
    return JSONResponse(content=content, headers=headers)


@app.get("/api/test2")
async def test2():
    """Return a greeting plus every ZaneTest row whose ``a`` is not NULL."""
    # `!= None` is deliberate: the comparison builds a query expression
    # (presumably via ORM operator overloading), so `is not None` would not
    # produce the same filter.
    users = ZaneTest.select().where(ZaneTest.a != None)  # noqa: E711
    users = [__build_user_object(dt) for dt in users]
    return {"message": "Hello World", "users": users}