nova-api/api/moderation.py

79 lines
2.5 KiB
Python
Raw Normal View History

2023-08-12 17:49:31 +02:00
"""This module contains functions for checking if a message violates the moderation policy."""
import time
2023-08-04 03:30:56 +02:00
import asyncio
2023-08-06 21:42:07 +02:00
import aiohttp
2023-08-06 21:42:07 +02:00
import proxies
import provider_auth
2023-08-06 21:42:07 +02:00
import load_balancing
2023-08-04 03:30:56 +02:00
2023-08-12 17:49:31 +02:00
from typing import Union
async def is_policy_violated(
    inp: Union[str, list],
    *,
    attempts: int = 3,
    timeout: float = 2,
) -> Union[str, bool]:
    """
    ### Check if a message violates the moderation policy.
    You can either pass a list of messages consisting of dicts with "role" and "content", as used in the API parameter,
    a list of plain strings, or just a simple string.

    Returns the name of the highest-scoring moderation category (a non-empty, truthy string)
    if the input is flagged, and False otherwise.
    False is also returned when every attempt fails (the check fails open).

    Keyword arguments:
    - attempts: how many moderation requests to try before giving up (default 3).
    - timeout: total timeout in seconds for each individual request (default 2).
    """
    text = inp

    if isinstance(inp, list):
        # `inp and ...` guards the inp[0] lookup so an empty list yields '' instead of IndexError.
        if inp and isinstance(inp[0], dict):
            # One line per message. Messages whose content is None contribute an empty line
            # rather than crashing the concatenation.
            text = ''.join(f"{msg.get('content') or ''}\n" for msg in inp)
        else:
            text = '\n'.join(inp)

    print(f'[i] checking moderation for {text}')

    for _ in range(attempts):
        req = await load_balancing.balance_organic_request(
            {
                'path': '/v1/moderations',
                'payload': {'input': text}
            }
        )
        print(f'[i] moderation request sent to {req["url"]}')

        async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
            try:
                start = time.perf_counter()

                # NOTE(review): ssl=False disables TLS certificate verification,
                # presumably for the proxy setup. Confirm this is intended.
                async with session.request(
                    method=req.get('method', 'POST'),
                    url=req['url'],
                    data=req.get('data'),
                    json=req.get('payload'),
                    headers=req.get('headers'),
                    cookies=req.get('cookies'),
                    ssl=False,
                    timeout=aiohttp.ClientTimeout(total=timeout),
                ) as res:
                    res.raise_for_status()
                    json_response = await res.json()
                    print(json_response)

                    result = json_response['results'][0]
                    categories = result['category_scores']

                    print(f'[i] moderation check took {time.perf_counter() - start:.2f}s')
                    if result['flagged']:
                        # Report the category with the highest score as the reason.
                        return max(categories, key=categories.get)
                    return False

            except Exception as exc:
                # Only an HTTP 401 response means the key was rejected. Matching '401'
                # anywhere in the message text could also hit URLs or unrelated errors.
                if isinstance(exc, aiohttp.ClientResponseError) and exc.status == 401:
                    await provider_auth.invalidate_key(req.get('provider_auth'))

                print('[!] moderation error:', type(exc), exc)

    # Every attempt failed: report "not flagged" (fail open), presumably so a
    # moderation outage does not block requests.
    return False
2023-08-04 03:30:56 +02:00
if __name__ == '__main__':
    # Manual smoke test: run a single moderation check and print the verdict.
    sample_text = 'I wanna kill myself'
    verdict = asyncio.run(is_policy_violated(sample_text))
    print(verdict)