Compare commits

...

3 commits

Author SHA1 Message Date
nsde ef3a549030 Made everything 4x faster 2023-09-02 22:30:11 +02:00
nsde 735d0e025b Moderation caching 2023-09-02 22:26:43 +02:00
nsde 4256a5ca9d Added caching and custom endpoint for own moderation path 2023-09-02 22:09:57 +02:00
3 changed files with 35 additions and 52 deletions

View file

@@ -2,9 +2,11 @@
 import fastapi
 import pydantic
+import functools
 from rich import print
 from dotenv import load_dotenv
+from json import JSONDecodeError
 from bson.objectid import ObjectId
 from slowapi.errors import RateLimitExceeded
 from slowapi.middleware import SlowAPIMiddleware
@@ -15,6 +17,7 @@ from helpers import network
 import core
 import handler
+import moderation

 load_dotenv()
@@ -65,3 +68,15 @@ async def root():
 async def v1_handler(request: fastapi.Request):
     res = await handler.handle(request)
     return res
+
+@functools.lru_cache()
+@app.post('/moderate')
+async def moderate(request: fastapi.Request):
+    try:
+        prompt = await request.json()
+        prompt = prompt['text']
+    except (KeyError, JSONDecodeError):
+        return fastapi.Response(status_code=400)
+
+    result = await moderation.is_policy_violated__own_model(prompt)
+    return result or ''
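
Note on the new endpoint: functools.lru_cache keys on its arguments, and every incoming fastapi.Request is a distinct object, so the decorator as written never gets a cache hit; applied to a coroutine function it would also cache one-shot coroutine objects rather than results. Below is a minimal sketch of caching on the prompt text instead, reusing the aiocache pattern this changeset introduces in the moderation module. check_text is a hypothetical stand-in for moderation.is_policy_violated__own_model, not the repo's code.

import fastapi
import aiocache
from json import JSONDecodeError

app = fastapi.FastAPI()
cache = aiocache.Cache(aiocache.SimpleMemoryCache)

async def check_text(text: str):
    # hypothetical stand-in for moderation.is_policy_violated__own_model;
    # returns a description string on a violation, False otherwise
    return False

@app.post('/moderate')
async def moderate(request: fastapi.Request):
    try:
        prompt = (await request.json())['text']
    except (KeyError, JSONDecodeError):
        return fastapi.Response(status_code=400)

    # key the cache on the prompt itself, not on the Request object
    if not await cache.exists(prompt):
        await cache.set(prompt, await check_text(prompt))
    return await cache.get(prompt) or ''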

View file

@@ -2,15 +2,13 @@
 import time
 import asyncio
 import aiohttp
+import aiocache
 import profanity_check
-import proxies
-import provider_auth
-import load_balancing
 from typing import Union

+cache = aiocache.Cache(aiocache.SimpleMemoryCache)

 def input_to_text(inp: Union[str, list]) -> str:
     """Converts the input to a string."""
@@ -28,60 +26,30 @@ def input_to_text(inp: Union[str, list]) -> str:
     return text

 async def is_policy_violated(inp: Union[str, list]) -> bool:
-    """Checks if the input violates the moderation policy.
+    """
+    ### Check if a message violates the moderation policy.
     You can either pass a list of messages consisting of dicts with "role" and "content", as used in the API parameter,
     or just a simple string.

     Returns True if the message violates the policy, False otherwise.
     """
-    text = input_to_text(inp)
-    return await is_policy_violated__own_model(text)
-
-    for _ in range(1):
-        req = await load_balancing.balance_organic_request(
-            {
-                'path': '/v1/moderations',
-                'payload': {'input': text}
-            }
-        )
-        async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
-            try:
-                async with session.request(
-                    method=req.get('method', 'POST'),
-                    url=req['url'],
-                    data=req.get('data'),
-                    json=req.get('payload'),
-                    headers=req.get('headers'),
-                    cookies=req.get('cookies'),
-                    ssl=False,
-                    timeout=aiohttp.ClientTimeout(total=2),
-                ) as res:
-                    res.raise_for_status()
-                    json_response = await res.json()
-                    categories = json_response['results'][0]['category_scores']
-                    if json_response['results'][0]['flagged']:
-                        return max(categories, key=categories.get)
-                    return False
-            except Exception as exc:
-                if '401' in str(exc):
-                    await provider_auth.invalidate_key(req.get('provider_auth'))
-                print('[!] moderation error:', type(exc), exc)
-                continue
+    # use aio cache to cache the result
+    inp = input_to_text(inp)
+
+    # utilize the cache
+    if await cache.exists(inp):
+        return await cache.get(inp)
+    else:
+        await cache.set(inp, await is_policy_violated__own_model(inp))
+        return await cache.get(inp)

 async def is_policy_violated__own_model(inp: Union[str, list]) -> bool:
     """Checks if the input violates the moderation policy using our own model."""
     inp = input_to_text(inp)
     if profanity_check.predict([inp])[0]:
-        return 'own model detected'
+        return 'NovaAI\'s selfhosted moderation model detected unsuitable content.'
     return False

 if __name__ == '__main__':
-    print(asyncio.run(is_policy_violated('kill ms')))
+    for i in range(10):
+        start = time.perf_counter()
+        print(asyncio.run(is_policy_violated('kill ms')))
+        print((time.perf_counter() - start) * 1000)
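
The cache is what the "Made everything 4x faster" commit is measuring: the first is_policy_violated call pays for profanity_check inference, every repeat is answered from SimpleMemoryCache. A rough, self-contained timing harness in the spirit of the __main__ block above, assuming the aiocache and profanity-check packages are installed; check is an illustrative reimplementation, not the module's code.

import time
import asyncio
import aiocache
import profanity_check

cache = aiocache.Cache(aiocache.SimpleMemoryCache)

async def check(text: str):
    # read-through cache around the profanity_check model
    if await cache.exists(text):
        return await cache.get(text)
    result = bool(profanity_check.predict([text])[0])
    await cache.set(text, result)
    return result

async def main():
    for _ in range(3):
        start = time.perf_counter()
        flagged = await check('kill ms')
        print(flagged, f'{(time.perf_counter() - start) * 1000:.2f} ms')

if __name__ == '__main__':
    asyncio.run(main())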

View file

@@ -114,7 +114,7 @@ async def stream(
         cookies=target_request.get('cookies'),
         ssl=False,
         timeout=aiohttp.ClientTimeout(
-            connect=2,
+            connect=0.5,
             total=float(os.getenv('TRANSFER_TIMEOUT', '120'))
         ),
     ) as response:
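
Lowering connect from 2 to 0.5 seconds only bounds connection setup (acquiring and establishing the connection), so a dead upstream or proxy fails over quickly, while the total budget from TRANSFER_TIMEOUT (120 s by default) still covers long streamed responses. A self-contained sketch of the same timeout split; the URL is illustrative only.

import os
import asyncio
import aiohttp

async def fetch_streamed(url: str) -> int:
    timeout = aiohttp.ClientTimeout(
        connect=0.5,  # fail fast if the host/proxy is unreachable
        total=float(os.getenv('TRANSFER_TIMEOUT', '120'))  # whole transfer
    )
    received = 0
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(url) as response:
            response.raise_for_status()
            async for chunk in response.content.iter_chunked(1024):
                received += len(chunk)
    return received

if __name__ == '__main__':
    print(asyncio.run(fetch_streamed('https://example.com/')))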