Changed timeout to 500

This commit is contained in:
nsde 2023-09-10 16:22:46 +02:00
parent 7d914bc147
commit c23bc7a5d3
4 changed files with 15 additions and 28 deletions

2
.gitignore vendored
View file

@ -1,3 +1,5 @@
*.zip
last_update.txt last_update.txt
*.log.json *.log.json

View file

@ -1,11 +1,13 @@
"""This module contains functions for checking if a message violates the moderation policy.""" """This module contains functions for checking if a message violates the moderation policy."""
import time import time
import difflib
import asyncio import asyncio
import aiocache import aiocache
import profanity_check import profanity_check
from typing import Union from typing import Union
from Levenshtein import distance
cache = aiocache.Cache(aiocache.SimpleMemoryCache) cache = aiocache.Cache(aiocache.SimpleMemoryCache)
@ -41,15 +43,13 @@ async def is_policy_violated(inp: Union[str, list]) -> bool:
async def is_policy_violated__own_model(inp: Union[str, list]) -> bool: async def is_policy_violated__own_model(inp: Union[str, list]) -> bool:
"""Checks if the input violates the moderation policy using our own model.""" """Checks if the input violates the moderation policy using our own model."""
inp = input_to_text(inp) inp = input_to_text(inp).lower()
if profanity_check.predict([inp])[0]: if profanity_check.predict([inp])[0]:
return 'NovaAI\'s selfhosted moderation model detected unsuitable content.' return 'Sorry, our moderation AI has detected NSFW content in your message.'
return False return False
if __name__ == '__main__': if __name__ == '__main__':
for i in range(10): while True:
start = time.perf_counter() print(asyncio.run(is_policy_violated(input('-> '))))
print(asyncio.run(is_policy_violated('kill ms')))
print((time.perf_counter() - start) * 1000)

View file

@ -116,7 +116,7 @@ async def stream(
ssl=False, ssl=False,
timeout=aiohttp.ClientTimeout( timeout=aiohttp.ClientTimeout(
connect=0.5, connect=0.5,
total=float(os.getenv('TRANSFER_TIMEOUT', '120')) total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
), ),
) as response: ) as response:

View file

@ -43,7 +43,7 @@ async def test_server():
else: else:
return time.perf_counter() - request_start return time.perf_counter() - request_start
async def test_chat(model: str=MODEL, messages: List[dict]=None) -> dict: async def test_chat_non_stream(model: str=MODEL, messages: List[dict]=None) -> dict:
"""Tests an API api_endpoint.""" """Tests an API api_endpoint."""
json_data = { json_data = {
@ -107,21 +107,6 @@ async def test_models():
assert 'gpt-3.5-turbo' in all_models, 'The model gpt-3.5-turbo is not present in the models endpoint.' assert 'gpt-3.5-turbo' in all_models, 'The model gpt-3.5-turbo is not present in the models endpoint.'
return time.perf_counter() - request_start return time.perf_counter() - request_start
async def test_api_moderation() -> dict:
"""Tests the moderation endpoint."""
request_start = time.perf_counter()
async with httpx.AsyncClient() as client:
response = await client.post(
url=f'{api_endpoint}/moderations',
headers=HEADERS,
timeout=5,
json={'input': 'fuck you, die'}
)
assert response.json()['results'][0]['flagged'] == True, 'Profanity not detected'
return time.perf_counter() - request_start
# ========================================================================================== # ==========================================================================================
async def demo(): async def demo():
@ -137,16 +122,16 @@ async def demo():
else: else:
raise ConnectionError('API Server is not running.') raise ConnectionError('API Server is not running.')
print('[lightblue]Checking if the API works...') print('Checking non-streamed chat completions...')
print(await test_chat()) print(await test_chat_non_stream())
# print('[lightblue]Checking if SDXL image generation works...') # print('[lightblue]Checking if SDXL image generation works...')
# print(await test_sdxl()) # print(await test_sdxl())
print('[lightblue]Checking if the moderation endpoint works...') # print('[lightblue]Checking if the moderation endpoint works...')
print(await test_api_moderation()) # print(await test_api_moderation())
print('[lightblue]Checking the models endpoint...') print('Checking the models endpoint...')
print(await test_models()) print(await test_models())
except Exception as exc: except Exception as exc: