Mirror of https://github.com/NovaOSS/nova-api.git, synced 2024-11-25 16:43:58 +01:00
Added caching and custom endpoint for own moderation path
parent 6bd5dc534c
commit 4256a5ca9d
api/main.py | 15 +++++++++++++++
api/main.py

@@ -2,9 +2,11 @@
 import fastapi
 import pydantic
+import functools
 
 from rich import print
 from dotenv import load_dotenv
+from json import JSONDecodeError
 from bson.objectid import ObjectId
 from slowapi.errors import RateLimitExceeded
 from slowapi.middleware import SlowAPIMiddleware
 
@@ -15,6 +17,7 @@ from helpers import network
 import core
 import handler
+import moderation
 
 load_dotenv()
 
 
@@ -65,3 +68,15 @@ async def root():
 async def v1_handler(request: fastapi.Request):
     res = await handler.handle(request)
     return res
+
+@functools.lru_cache()
+@app.post('/moderate')
+async def moderate(request: fastapi.Request):
+    try:
+        prompt = await request.json()
+        prompt = prompt['text']
+    except (KeyError, JSONDecodeError):
+        return fastapi.Response(status_code=400)
+
+    result = await moderation.is_policy_violated__own_model(prompt)
+    return result or ''
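The new /moderate route reads the request body as JSON and requires a `text` field; malformed JSON or a missing key is answered with 400. On success it returns whatever the self-hosted check reports: a truthy verdict string when the text is flagged, or an empty body when it is clean. A minimal client sketch, assuming the API is reachable at http://localhost:8000 (the base URL is a placeholder, not taken from this commit):

import asyncio
import aiohttp

async def moderate(text: str) -> str:
    # POST {"text": ...} to the new endpoint; an empty response body
    # means nothing was flagged.
    async with aiohttp.ClientSession() as session:
        async with session.post('http://localhost:8000/moderate', json={'text': text}) as res:
            if res.status == 400:
                raise ValueError('body must be JSON of the form {"text": "..."}')
            return await res.text()

print(asyncio.run(moderate('hello world')))

One caveat: decorators apply bottom-up, so app.post('/moderate') registers the bare coroutine with the router (FastAPI route decorators return the function unchanged) before functools.lru_cache() wraps it, meaning the cache never sits in the request path. And a cache keyed on the per-request fastapi.Request object could not produce a hit anyway, since every request is a fresh object. An async-safe caching sketch follows the moderation.py diff below.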
api/moderation.py

@@ -1,14 +1,9 @@
 """This module contains functions for checking if a message violates the moderation policy."""
 
-import time
-import asyncio
-import aiohttp
+import functools
 import profanity_check
 
-import proxies
-import provider_auth
-import load_balancing
 
 from typing import Union
 
 def input_to_text(inp: Union[str, list]) -> str:
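The self-hosted "own model" used below is the profanity_check package, a linear scikit-learn classifier bundled with the library, so moderation runs locally without an upstream API call. A quick sketch of its interface (the printed values are illustrative):

import profanity_check

# Both functions take a sequence of strings: predict() returns a 0/1
# flag per input, predict_prob() the underlying probability.
flags = profanity_check.predict(['have a nice day', 'you are an idiot'])
probs = profanity_check.predict_prob(['have a nice day'])

print(flags)  # e.g. [0 1]
print(probs)  # e.g. [0.03]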
@@ -27,59 +22,19 @@ def input_to_text(inp: Union[str, list]) -> str:
     return text
 
+@functools.lru_cache()
 async def is_policy_violated(inp: Union[str, list]) -> bool:
-    """Checks if the input violates the moderation policy.
+    """
+    ### Check if a message violates the moderation policy.
+    You can either pass a list of messages consisting of dicts with "role" and "content", as used in the API parameter,
+    or just a simple string.
+
+    Returns True if the message violates the policy, False otherwise.
     """
 
     text = input_to_text(inp)
+    return await is_policy_violated__own_model(text)
 
-    for _ in range(1):
-        req = await load_balancing.balance_organic_request(
-            {
-                'path': '/v1/moderations',
-                'payload': {'input': text}
-            }
-        )
-
-        async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
-            try:
-                async with session.request(
-                    method=req.get('method', 'POST'),
-                    url=req['url'],
-                    data=req.get('data'),
-                    json=req.get('payload'),
-                    headers=req.get('headers'),
-                    cookies=req.get('cookies'),
-                    ssl=False,
-                    timeout=aiohttp.ClientTimeout(total=2),
-                ) as res:
-                    res.raise_for_status()
-                    json_response = await res.json()
-
-                    categories = json_response['results'][0]['category_scores']
-
-                    if json_response['results'][0]['flagged']:
-                        return max(categories, key=categories.get)
-
-                    return False
-
-            except Exception as exc:
-                if '401' in str(exc):
-                    await provider_auth.invalidate_key(req.get('provider_auth'))
-                print('[!] moderation error:', type(exc), exc)
-                continue
-
-    return await is_policy_violated__own_model(inp)
 
 async def is_policy_violated__own_model(inp: Union[str, list]) -> bool:
     """Checks if the input violates the moderation policy using our own model."""
 
     inp = input_to_text(inp)
 
     if profanity_check.predict([inp])[0]:
-        return 'own model detected'
+        return 'NovaAI\'s selfhosted moderation model detected unsuitable content.'
 
     return False
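Two caveats around the @functools.lru_cache() added to is_policy_violated. First, lru_cache hashes the call arguments, so passing the documented list-of-messages form raises TypeError: unhashable type: 'list' before the body runs. Second, on an `async def` it caches the coroutine object rather than its result, and a coroutine can only be awaited once, so repeating the same string prompt raises RuntimeError: cannot reuse already awaited coroutine. Note also that, despite the docstring, the function resolves to a verdict string or False, which is what lets main.py return `result or ''`. Below is a sketch of an async-safe alternative that caches awaited results keyed by the (hashable) normalized text; async_cache is a hypothetical helper, not part of nova-api:

import asyncio
import functools
import profanity_check

def async_cache(func):
    """Memoize an async function by its (hashable) arguments.

    Unlike functools.lru_cache, which would store the single-use
    coroutine object, this stores the awaited result. Unbounded:
    entries are never evicted.
    """
    cache = {}

    @functools.wraps(func)
    async def wrapper(*args):
        if args not in cache:
            cache[args] = await func(*args)
        return cache[args]

    return wrapper

@async_cache
async def is_policy_violated__own_model(text: str):
    # Same logic as the committed function, repeated here only to
    # demonstrate the decorator.
    if profanity_check.predict([text])[0]:
        return 'NovaAI\'s selfhosted moderation model detected unsuitable content.'
    return False

print(asyncio.run(is_policy_violated__own_model('hello')))  # False, and now cached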