mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 22:23:56 +01:00
Compare commits
3 commits
6bd5dc534c
...
ef3a549030
Author | SHA1 | Date | |
---|---|---|---|
ef3a549030 | |||
735d0e025b | |||
4256a5ca9d |
15
api/main.py
15
api/main.py
|
@ -2,9 +2,11 @@
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import pydantic
|
import pydantic
|
||||||
|
import functools
|
||||||
|
|
||||||
from rich import print
|
from rich import print
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from json import JSONDecodeError
|
||||||
from bson.objectid import ObjectId
|
from bson.objectid import ObjectId
|
||||||
from slowapi.errors import RateLimitExceeded
|
from slowapi.errors import RateLimitExceeded
|
||||||
from slowapi.middleware import SlowAPIMiddleware
|
from slowapi.middleware import SlowAPIMiddleware
|
||||||
|
@ -15,6 +17,7 @@ from helpers import network
|
||||||
|
|
||||||
import core
|
import core
|
||||||
import handler
|
import handler
|
||||||
|
import moderation
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
@ -65,3 +68,15 @@ async def root():
|
||||||
async def v1_handler(request: fastapi.Request):
|
async def v1_handler(request: fastapi.Request):
|
||||||
res = await handler.handle(request)
|
res = await handler.handle(request)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@functools.lru_cache()
|
||||||
|
@app.post('/moderate')
|
||||||
|
async def moderate(request: fastapi.Request):
|
||||||
|
try:
|
||||||
|
prompt = await request.json()
|
||||||
|
prompt = prompt['text']
|
||||||
|
except (KeyError, JSONDecodeError):
|
||||||
|
return fastapi.Response(status_code=400)
|
||||||
|
|
||||||
|
result = await moderation.is_policy_violated__own_model(prompt)
|
||||||
|
return result or ''
|
||||||
|
|
|
@ -2,15 +2,13 @@
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
import aiocache
|
||||||
import profanity_check
|
import profanity_check
|
||||||
|
|
||||||
import proxies
|
|
||||||
import provider_auth
|
|
||||||
import load_balancing
|
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
cache = aiocache.Cache(aiocache.SimpleMemoryCache)
|
||||||
|
|
||||||
def input_to_text(inp: Union[str, list]) -> str:
|
def input_to_text(inp: Union[str, list]) -> str:
|
||||||
"""Converts the input to a string."""
|
"""Converts the input to a string."""
|
||||||
|
|
||||||
|
@ -28,60 +26,30 @@ def input_to_text(inp: Union[str, list]) -> str:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
async def is_policy_violated(inp: Union[str, list]) -> bool:
|
async def is_policy_violated(inp: Union[str, list]) -> bool:
|
||||||
|
"""Checks if the input violates the moderation policy.
|
||||||
"""
|
"""
|
||||||
### Check if a message violates the moderation policy.
|
# use aio cache to cache the result
|
||||||
You can either pass a list of messages consisting of dicts with "role" and "content", as used in the API parameter,
|
inp = input_to_text(inp)
|
||||||
or just a simple string.
|
|
||||||
|
|
||||||
Returns True if the message violates the policy, False otherwise.
|
# utilize the cache
|
||||||
"""
|
if await cache.exists(inp):
|
||||||
|
return await cache.get(inp)
|
||||||
text = input_to_text(inp)
|
else:
|
||||||
return await is_policy_violated__own_model(text)
|
await cache.set(inp, await is_policy_violated__own_model(inp))
|
||||||
|
return await cache.get(inp)
|
||||||
for _ in range(1):
|
|
||||||
req = await load_balancing.balance_organic_request(
|
|
||||||
{
|
|
||||||
'path': '/v1/moderations',
|
|
||||||
'payload': {'input': text}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
|
|
||||||
try:
|
|
||||||
async with session.request(
|
|
||||||
method=req.get('method', 'POST'),
|
|
||||||
url=req['url'],
|
|
||||||
data=req.get('data'),
|
|
||||||
json=req.get('payload'),
|
|
||||||
headers=req.get('headers'),
|
|
||||||
cookies=req.get('cookies'),
|
|
||||||
ssl=False,
|
|
||||||
timeout=aiohttp.ClientTimeout(total=2),
|
|
||||||
) as res:
|
|
||||||
res.raise_for_status()
|
|
||||||
json_response = await res.json()
|
|
||||||
|
|
||||||
categories = json_response['results'][0]['category_scores']
|
|
||||||
|
|
||||||
if json_response['results'][0]['flagged']:
|
|
||||||
return max(categories, key=categories.get)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
if '401' in str(exc):
|
|
||||||
await provider_auth.invalidate_key(req.get('provider_auth'))
|
|
||||||
print('[!] moderation error:', type(exc), exc)
|
|
||||||
continue
|
|
||||||
|
|
||||||
async def is_policy_violated__own_model(inp: Union[str, list]) -> bool:
|
async def is_policy_violated__own_model(inp: Union[str, list]) -> bool:
|
||||||
|
"""Checks if the input violates the moderation policy using our own model."""
|
||||||
|
|
||||||
inp = input_to_text(inp)
|
inp = input_to_text(inp)
|
||||||
|
|
||||||
if profanity_check.predict([inp])[0]:
|
if profanity_check.predict([inp])[0]:
|
||||||
return 'own model detected'
|
return 'NovaAI\'s selfhosted moderation model detected unsuitable content.'
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
for i in range(10):
|
||||||
|
start = time.perf_counter()
|
||||||
print(asyncio.run(is_policy_violated('kill ms')))
|
print(asyncio.run(is_policy_violated('kill ms')))
|
||||||
|
print((time.perf_counter() - start) * 1000)
|
||||||
|
|
|
@ -114,7 +114,7 @@ async def stream(
|
||||||
cookies=target_request.get('cookies'),
|
cookies=target_request.get('cookies'),
|
||||||
ssl=False,
|
ssl=False,
|
||||||
timeout=aiohttp.ClientTimeout(
|
timeout=aiohttp.ClientTimeout(
|
||||||
connect=2,
|
connect=0.5,
|
||||||
total=float(os.getenv('TRANSFER_TIMEOUT', '120'))
|
total=float(os.getenv('TRANSFER_TIMEOUT', '120'))
|
||||||
),
|
),
|
||||||
) as response:
|
) as response:
|
||||||
|
|
Loading…
Reference in a new issue