diff --git a/api/core.py b/api/core.py index 6dd6381..f43d71c 100644 --- a/api/core.py +++ b/api/core.py @@ -3,6 +3,8 @@ import os import sys +from helpers import errors + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) sys.path.append(project_root) @@ -24,19 +26,17 @@ load_dotenv() router = fastapi.APIRouter(tags=['core']) async def check_core_auth(request): - """ - - ### Checks the request's auth - Auth is taken from environment variable `CORE_API_KEY` - + """Checks the core API key. Returns nothing if it's valid, otherwise returns an error. """ received_auth = request.headers.get('Authorization') correct_core_api = os.environ['CORE_API_KEY'] # use hmac.compare_digest to prevent timing attacks - if received_auth and hmac.compare_digest(received_auth, correct_core_api): - return fastapi.Response(status_code=403, content='Invalid or no API key given.') + if not (received_auth and hmac.compare_digest(received_auth, correct_core_api)): + return await errors.error(401, 'The core API key you provided is invalid.', 'Check the `Authorization` header.') + + return None @router.get('/users') async def get_users(discord_id: int, incoming_request: fastapi.Request): @@ -50,7 +50,7 @@ async def get_users(discord_id: int, incoming_request: fastapi.Request): manager = UserManager() user = await manager.user_by_discord_id(discord_id) if not user: - return fastapi.Response(status_code=404, content='User not found.') + return await errors.error(404, 'Discord user not found in the API database.', 'Check the `discord_id` parameter.') return user @@ -83,7 +83,7 @@ async def create_user(incoming_request: fastapi.Request): payload = await incoming_request.json() discord_id = payload.get('discord_id') except (json.decoder.JSONDecodeError, AttributeError): - return fastapi.Response(status_code=400, content='Invalid or no payload received.') + return await errors.error(400, 'Invalid or no payload received.', 'The payload must be a JSON object with a `discord_id` key.') # Create the user manager = UserManager() @@ -106,9 +106,12 @@ async def update_user(incoming_request: fastapi.Request): discord_id = payload.get('discord_id') updates = payload.get('updates') except (json.decoder.JSONDecodeError, AttributeError): - return fastapi.Response(status_code=400, content='Invalid or no payload received.') + return await errors.error( + 400, 'Invalid or no payload received.', + 'The payload must be a JSON object with a `discord_id` key and an `updates` key.' + ) - # Update the user + # Update the user manager = UserManager() user = await manager.update_by_discord_id(discord_id, updates) @@ -123,9 +126,23 @@ async def run_checks(incoming_request: fastapi.Request): if auth_error: return auth_error + try: + chat = await checks.client.test_chat() + except Exception: + chat = None + + try: + moderation = await checks.client.test_api_moderation() + except Exception: + moderation = None + + try: + models = await checks.client.test_models() + except Exception: + models = None + return { - 'library': await checks.client.test_library(), - 'library_moderation': await checks.client.test_library_moderation(), - 'api_moderation': await checks.client.test_api_moderation(), - 'models': await checks.client.test_models() + 'chat/completions': chat, + 'models': models, + 'moderations': moderation, } diff --git a/api/main.py b/api/main.py index b157c2a..85139c4 100644 --- a/api/main.py +++ b/api/main.py @@ -36,7 +36,8 @@ async def root(): 'hi': 'Welcome to the Nova API!', 'learn_more_here': 'https://nova-oss.com', 'github': 'https://github.com/novaoss/nova-api', - 'core_api_docs_for_nova_developers': '/docs' + 'core_api_docs_for_nova_developers': '/docs', + 'ping': 'pong' } app.add_route('/v1/{path:path}', transfer.handle, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']) diff --git a/api/moderation.py b/api/moderation.py index dd65872..bcbff1d 100644 --- a/api/moderation.py +++ b/api/moderation.py @@ -1,5 +1,6 @@ """This module contains functions for checking if a message violates the moderation policy.""" +import time import asyncio import aiohttp @@ -29,6 +30,8 @@ async def is_policy_violated(inp: Union[str, list]) -> bool: else: text = '\n'.join(inp) + print(f'[i] checking moderation for {text}') + for _ in range(3): req = await load_balancing.balance_organic_request( { @@ -36,9 +39,11 @@ async def is_policy_violated(inp: Union[str, list]) -> bool: 'payload': {'input': text} } ) + print(f'[i] moderation request sent to {req["url"]}') async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session: try: + start = time.perf_counter() async with session.request( method=req.get('method', 'POST'), url=req['url'], @@ -51,14 +56,19 @@ async def is_policy_violated(inp: Union[str, list]) -> bool: ) as res: res.raise_for_status() json_response = await res.json() + print(json_response) + categories = json_response['results'][0]['category_scores'] + print(f'[i] moderation check took {time.perf_counter() - start:.2f}s') + if json_response['results'][0]['flagged']: return max(categories, key=categories.get) return False except Exception as exc: + if '401' in str(exc): await provider_auth.invalidate_key(req.get('provider_auth')) print('[!] moderation error:', type(exc), exc) diff --git a/api/transfer.py b/api/transfer.py index 3a03524..51d3bc5 100644 --- a/api/transfer.py +++ b/api/transfer.py @@ -26,7 +26,7 @@ async def handle(incoming_request): Checks method, token amount, auth and cost along with if request is NSFW. """ users = UserManager() - path = incoming_request.url.path + path = incoming_request.url.path.replace('v1/v1', 'v1').replace('//', '/') if '/models' in path: return fastapi.responses.JSONResponse(content=models_list) @@ -62,10 +62,11 @@ async def handle(incoming_request): cost = costs['chat-models'].get(payload.get('model'), cost) policy_violation = False - if 'chat/completions' in path or ('input' in payload or 'prompt' in payload): - inp = payload.get('input', payload.get('prompt', '')) - if inp and len(inp) > 2 and not inp.isnumeric(): - policy_violation = await moderation.is_policy_violated(inp) + if '/moderations' not in path: + if '/chat/completions' in path or ('input' in payload or 'prompt' in payload): + inp = payload.get('input', payload.get('prompt', '')) + if inp and len(inp) > 2 and not inp.isnumeric(): + policy_violation = await moderation.is_policy_violated(inp) if policy_violation: return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.') diff --git a/checks/__main__.py b/checks/__main__.py index b6ed337..d4b1290 100644 --- a/checks/__main__.py +++ b/checks/__main__.py @@ -1,2 +1,4 @@ import client -client.demo() +import asyncio + +asyncio.run(client.demo()) diff --git a/checks/client.py b/checks/client.py index 8d6622b..3702fbe 100644 --- a/checks/client.py +++ b/checks/client.py @@ -22,82 +22,103 @@ MESSAGES = [ } ] -api_endpoint = 'http://localhost:2332' +api_endpoint = 'http://localhost:2332/v1' async def test_server(): """Tests if the API server is running.""" try: - return httpx.get(f'{api_endpoint.replace("/v1", "")}').json()['status'] == 'ok' + request_start = time.perf_counter() + async with httpx.AsyncClient() as client: + response = await client.get( + url=f'{api_endpoint.replace("/v1", "")}', + timeout=3 + ) + response.raise_for_status() + + assert response.json()['ping'] == 'pong', 'The API did not return a correct response.' except httpx.ConnectError as exc: raise ConnectionError(f'API is not running on port {api_endpoint}.') from exc -async def test_api(model: str=MODEL, messages: List[dict]=None) -> dict: + else: + return time.perf_counter() - request_start + +async def test_chat(model: str=MODEL, messages: List[dict]=None) -> dict: """Tests an API api_endpoint.""" json_data = { 'model': model, 'messages': messages or MESSAGES, - 'stream': True, + 'stream': False } - response = httpx.post( - url=f'{api_endpoint}/chat/completions', - headers=HEADERS, - json=json_data, - timeout=20 - ) - response.raise_for_status() + request_start = time.perf_counter() - return response.text + async with httpx.AsyncClient() as client: + response = await client.post( + url=f'{api_endpoint}/chat/completions', + headers=HEADERS, + json=json_data, + timeout=10, + ) + response.raise_for_status() -async def test_library(): + assert '2' in response.json()['choices'][0]['message']['content'], 'The API did not return a correct response.' + return time.perf_counter() - request_start + +async def test_library_chat(): """Tests if the api_endpoint is working with the OpenAI Python library.""" + request_start = time.perf_counter() completion = openai.ChatCompletion.create( model=MODEL, messages=MESSAGES ) - print(completion) - - return completion['choices'][0]['message']['content'] - -async def test_library_moderation(): - try: - return openai.Moderation.create('I wanna kill myself, I wanna kill myself; It\'s all I hear right now, it\'s all I hear right now') - except openai.error.InvalidRequestError: - return True + assert '2' in completion.choices[0]['message']['content'], 'The API did not return a correct response.' + return time.perf_counter() - request_start async def test_models(): - response = httpx.get( - url=f'{api_endpoint}/models', - headers=HEADERS, - timeout=5 - ) - response.raise_for_status() - return response.json() + """Tests the models endpoint.""" + + request_start = time.perf_counter() + async with httpx.AsyncClient() as client: + response = await client.get( + url=f'{api_endpoint}/models', + headers=HEADERS, + timeout=3 + ) + response.raise_for_status() + res = response.json() + + all_models = [model['id'] for model in res['data']] + + assert 'gpt-3.5-turbo' in all_models, 'The model gpt-3.5-turbo is not present in the models endpoint.' + return time.perf_counter() - request_start async def test_api_moderation() -> dict: - """Tests an API api_endpoint.""" + """Tests the moderation endpoint.""" - response = httpx.get( - url=f'{api_endpoint}/moderations', - headers=HEADERS, - timeout=20 - ) - response.raise_for_status() + request_start = time.perf_counter() + async with httpx.AsyncClient() as client: + response = await client.post( + url=f'{api_endpoint}/moderations', + headers=HEADERS, + timeout=5, + json={'input': 'fuck you, die'} + ) - return response.text + assert response.json()['results'][0]['flagged'] == True, 'Profanity not detected' + return time.perf_counter() - request_start # ========================================================================================== -def demo(): +async def demo(): """Runs all tests.""" try: for _ in range(30): - if test_server(): + if await test_server(): break print('Waiting until API Server is started up...') @@ -105,17 +126,17 @@ def demo(): else: raise ConnectionError('API Server is not running.') - print('[lightblue]Running a api endpoint to see if requests can go through...') - print(asyncio.run(test_api('gpt-3.5-turbo'))) + print('[lightblue]Checking if the API works...') + print(await test_chat()) - print('[lightblue]Checking if the API works with the python library...') - print(asyncio.run(test_library())) + print('[lightblue]Checking if the API works with the Python library...') + print(await test_library_chat()) print('[lightblue]Checking if the moderation endpoint works...') - print(asyncio.run(test_library_moderation())) + print(await test_api_moderation()) - print('[lightblue]Checking the /v1/models endpoint...') - print(asyncio.run(test_models())) + print('[lightblue]Checking the models endpoint...') + print(await test_models()) except Exception as exc: print('[red]Error: ' + str(exc)) @@ -131,4 +152,4 @@ HEADERS = { } if __name__ == '__main__': - demo() + asyncio.run(demo()) diff --git a/run/__main__.py b/run/__main__.py index f188ef0..ebd513b 100644 --- a/run/__main__.py +++ b/run/__main__.py @@ -23,4 +23,4 @@ if 'prod' in sys.argv: port = 2333 dev = False -os.system(f'cd api && uvicorn main:app{" --reload" if dev else ""} --host 0.0.0.0 --port {port} & python tests') +os.system(f'cd api && uvicorn main:app{" --reload" if dev else ""} --host 0.0.0.0 --port {port}')