From 0b86f7c26c4e2ab0cc181c2d9a790eafd2a5447a Mon Sep 17 00:00:00 2001 From: nsde Date: Wed, 9 Aug 2023 11:15:49 +0200 Subject: [PATCH] Added /v1/models and fixed key invalidation --- api/db/users.py | 4 ++-- api/moderation.py | 3 ++- api/streaming.py | 34 ++++++++++++++++++++-------------- api/transfer.py | 2 -- rewards/last_update.txt | 2 +- tests/__main__.py | 31 ++++++++++++++++++------------- 6 files changed, 43 insertions(+), 33 deletions(-) diff --git a/api/db/users.py b/api/db/users.py index ffcbbd2..50c428c 100644 --- a/api/db/users.py +++ b/api/db/users.py @@ -15,7 +15,7 @@ with open('config/credits.yml', encoding='utf8') as f: async def _get_mongo(collection_name: str): return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name] -async def create(discord_id: int=0) -> dict: +async def create(discord_id: str='') -> dict: """Adds a new user to the MongoDB collection.""" chars = string.ascii_letters + string.digits @@ -36,7 +36,7 @@ async def create(discord_id: int=0) -> dict: 'ban_reason': '', }, 'auth': { - 'discord': discord_id, + 'discord': str(discord_id), 'github': None } } diff --git a/api/moderation.py b/api/moderation.py index 251dacb..d0c8144 100644 --- a/api/moderation.py +++ b/api/moderation.py @@ -47,7 +47,8 @@ async def is_policy_violated(inp) -> bool: return False except Exception as exc: - # await provider_auth.invalidate_key(req.get('provider_auth')) + if '401' in str(exc): + await provider_auth.invalidate_key(req.get('provider_auth')) print('[!] moderation error:', type(exc), exc) continue diff --git a/api/streaming.py b/api/streaming.py index aeb109a..27009ee 100644 --- a/api/streaming.py +++ b/api/streaming.py @@ -10,6 +10,7 @@ from dotenv import load_dotenv from python_socks._errors import ProxyError import proxies +import provider_auth import load_balancing from db import logs, users, stats @@ -36,8 +37,6 @@ async def stream( input_tokens: int=0, incoming_request: starlette.requests.Request=None, ): - - payload = payload or DEMO_PAYLOAD is_chat = False is_stream = payload.get('stream', False) @@ -45,7 +44,6 @@ async def stream( is_chat = True model = payload['model'] - if is_chat and is_stream: chat_id = await chat.create_chat_id() @@ -97,7 +95,10 @@ async def stream( return for k, v in target_request.get('headers', {}).items(): - headers[k] = v + target_request['headers'][k] = v + + if target_request['method'] == 'GET' and not payload: + target_request['payload'] = None async with aiohttp.ClientSession(connector=proxies.default_proxy.connector) as session: try: @@ -108,24 +109,30 @@ async def stream( data=target_request.get('data'), json=target_request.get('payload'), - headers=headers, + headers=target_request.get('headers', {}), cookies=target_request.get('cookies'), ssl=False, timeout=aiohttp.ClientTimeout(total=float(os.getenv('TRANSFER_TIMEOUT', '120'))), ) as response: + if response.content_type == 'application/json': + data = await response.json() - if not is_stream: - json_response = await response.json() - - try: - response.raise_for_status() - except Exception as exc: - if 'Too Many Requests' in str(exc): + if data.get('code') == 'invalid_api_key': + await provider_auth.invalidate_key(target_request.get('provider_auth')) continue + if response.ok: + json_response = data + if is_stream: + try: + response.raise_for_status() + except Exception as exc: + if 'Too Many Requests' in str(exc): + continue + try: async for chunk in response.content.iter_any(): send = False @@ -180,10 +187,9 @@ async def stream( content=chat.CompletionStop ) yield chunk - yield 'data: [DONE]\n\n' - if not is_stream: + if not is_stream and json_response: yield json.dumps(json_response) # DONE ========================================================= diff --git a/api/transfer.py b/api/transfer.py index 153d866..7ed17e0 100644 --- a/api/transfer.py +++ b/api/transfer.py @@ -101,8 +101,6 @@ async def handle(incoming_request): # READY - # payload['user'] = str(user['_id']) - if 'chat/completions' in path and not payload.get('stream') is True: payload['stream'] = False diff --git a/rewards/last_update.txt b/rewards/last_update.txt index 3ad1aab..f0d3975 100644 --- a/rewards/last_update.txt +++ b/rewards/last_update.txt @@ -1 +1 @@ -1691460004.7354248 \ No newline at end of file +1691546405.042006 \ No newline at end of file diff --git a/tests/__main__.py b/tests/__main__.py index 0033ee7..3056715 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -40,11 +40,6 @@ def test_server(): def test_api(model: str=MODEL, messages: List[dict]=None) -> dict: """Tests an API api_endpoint.""" - headers = { - 'Content-Type': 'application/json', - 'Authorization': 'Bearer ' + closedai.api_key - } - json_data = { 'model': model, 'messages': messages or MESSAGES, @@ -53,7 +48,7 @@ def test_api(model: str=MODEL, messages: List[dict]=None) -> dict: response = httpx.post( url=f'{api_endpoint}/chat/completions', - headers=headers, + headers=HEADERS, json=json_data, timeout=20 ) @@ -74,25 +69,30 @@ def test_library(): def test_library_moderation(): return closedai.Moderation.create("I wanna kill myself, I wanna kill myself; It's all I hear right now, it's all I hear right now") +def test_models(): + response = httpx.get( + url=f'{api_endpoint}/models', + headers=HEADERS, + timeout=5 + ) + response.raise_for_status() + return response.json() + def test_all(): """Runs all tests.""" # print(test_server()) # print(test_api()) - print(test_library()) + # print(test_library()) # print(test_library_moderation()) - + print(test_models()) def test_api_moderation(model: str=MODEL, messages: List[dict]=None) -> dict: """Tests an API api_endpoint.""" - headers = { - 'Authorization': 'Bearer ' + closedai.api_key - } - response = httpx.get( url=f'{api_endpoint}/moderations', - headers=headers, + headers=HEADERS, timeout=20 ) response.raise_for_status() @@ -104,4 +104,9 @@ if __name__ == '__main__': closedai.api_base = api_endpoint closedai.api_key = os.getenv('TEST_NOVA_KEY') + HEADERS = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + closedai.api_key + } + test_all()