From c23bc7a5d352f699c06f8eba4b5ca30755ac4f5b Mon Sep 17 00:00:00 2001 From: nsde Date: Sun, 10 Sep 2023 16:22:46 +0200 Subject: [PATCH] Changed timeout to 500 --- .gitignore | 2 ++ api/moderation.py | 12 ++++++------ api/streaming.py | 2 +- checks/client.py | 27 ++++++--------------------- 4 files changed, 15 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index 08e41d1..29f5410 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +*.zip + last_update.txt *.log.json diff --git a/api/moderation.py b/api/moderation.py index c868af9..d587dd5 100644 --- a/api/moderation.py +++ b/api/moderation.py @@ -1,11 +1,13 @@ """This module contains functions for checking if a message violates the moderation policy.""" import time +import difflib import asyncio import aiocache import profanity_check from typing import Union +from Levenshtein import distance cache = aiocache.Cache(aiocache.SimpleMemoryCache) @@ -41,15 +43,13 @@ async def is_policy_violated(inp: Union[str, list]) -> bool: async def is_policy_violated__own_model(inp: Union[str, list]) -> bool: """Checks if the input violates the moderation policy using our own model.""" - inp = input_to_text(inp) + inp = input_to_text(inp).lower() if profanity_check.predict([inp])[0]: - return 'NovaAI\'s selfhosted moderation model detected unsuitable content.' + return 'Sorry, our moderation AI has detected NSFW content in your message.' return False if __name__ == '__main__': - for i in range(10): - start = time.perf_counter() - print(asyncio.run(is_policy_violated('kill ms'))) - print((time.perf_counter() - start) * 1000) + while True: + print(asyncio.run(is_policy_violated(input('-> ')))) diff --git a/api/streaming.py b/api/streaming.py index 0cca780..117fea9 100644 --- a/api/streaming.py +++ b/api/streaming.py @@ -116,7 +116,7 @@ async def stream( ssl=False, timeout=aiohttp.ClientTimeout( connect=0.5, - total=float(os.getenv('TRANSFER_TIMEOUT', '120')) + total=float(os.getenv('TRANSFER_TIMEOUT', '500')) ), ) as response: diff --git a/checks/client.py b/checks/client.py index 7c232e9..5a29264 100644 --- a/checks/client.py +++ b/checks/client.py @@ -43,7 +43,7 @@ async def test_server(): else: return time.perf_counter() - request_start -async def test_chat(model: str=MODEL, messages: List[dict]=None) -> dict: +async def test_chat_non_stream(model: str=MODEL, messages: List[dict]=None) -> dict: """Tests an API api_endpoint.""" json_data = { @@ -107,21 +107,6 @@ async def test_models(): assert 'gpt-3.5-turbo' in all_models, 'The model gpt-3.5-turbo is not present in the models endpoint.' return time.perf_counter() - request_start -async def test_api_moderation() -> dict: - """Tests the moderation endpoint.""" - - request_start = time.perf_counter() - async with httpx.AsyncClient() as client: - response = await client.post( - url=f'{api_endpoint}/moderations', - headers=HEADERS, - timeout=5, - json={'input': 'fuck you, die'} - ) - - assert response.json()['results'][0]['flagged'] == True, 'Profanity not detected' - return time.perf_counter() - request_start - # ========================================================================================== async def demo(): @@ -137,16 +122,16 @@ async def demo(): else: raise ConnectionError('API Server is not running.') - print('[lightblue]Checking if the API works...') - print(await test_chat()) + print('Checking non-streamed chat completions...') + print(await test_chat_non_stream()) # print('[lightblue]Checking if SDXL image generation works...') # print(await test_sdxl()) - print('[lightblue]Checking if the moderation endpoint works...') - print(await test_api_moderation()) + # print('[lightblue]Checking if the moderation endpoint works...') + # print(await test_api_moderation()) - print('[lightblue]Checking the models endpoint...') + print('Checking the models endpoint...') print(await test_models()) except Exception as exc: