2023-08-12 17:49:31 +02:00
|
|
|
"""This module contains functions for checking if a message violates the moderation policy."""
|
|
|
|
|
2023-08-23 23:27:09 +02:00
|
|
|
import time
|
2023-08-04 03:30:56 +02:00
|
|
|
import asyncio
|
2023-08-06 21:42:07 +02:00
|
|
|
import aiohttp
|
2023-08-30 22:13:23 +02:00
|
|
|
import profanity_check
|
2023-08-07 23:28:24 +02:00
|
|
|
|
2023-08-06 21:42:07 +02:00
|
|
|
import proxies
|
2023-08-07 23:28:24 +02:00
|
|
|
import provider_auth
|
2023-08-06 21:42:07 +02:00
|
|
|
import load_balancing
|
2023-08-04 03:30:56 +02:00
|
|
|
|
2023-08-12 17:49:31 +02:00
|
|
|
from typing import Union
|
|
|
|
|
2023-08-30 22:13:23 +02:00
|
|
|
def input_to_text(inp: Union[str, list]) -> str:
    """Converts the input to a string.

    Args:
        inp: One of
            - a plain string, returned unchanged;
            - a list of strings, joined with newlines;
            - a list of message dicts with a 'content' key, where each
              content is followed by a newline (including the last one).

    Returns:
        The flattened text. An empty list yields an empty string.
    """

    if not isinstance(inp, list):
        return inp

    # Guard against an empty list, which would otherwise fail on inp[0] below.
    if not inp:
        return ''

    if isinstance(inp[0], dict):
        # Every message's content gets a trailing newline, the last one included.
        return ''.join(msg['content'] + '\n' for msg in inp)

    return '\n'.join(inp)
|
|
|
|
|
|
|
|
async def is_policy_violated(inp: Union[str, list]) -> Union[bool, str]:
    """
    ### Check if a message violates the moderation policy.

    You can either pass a list of messages consisting of dicts with "role" and "content", as used in the API parameter,
    or just a simple string.

    Returns a non-empty string describing why the input was flagged if it violates the policy,
    False otherwise. Because the flagged result is truthy, callers that only test the result's
    truthiness behave as if True were returned.
    """

    text = input_to_text(inp)

    # Only the local profanity model is consulted. A remote `/v1/moderations`
    # request loop used to sit after this return statement, where it could never
    # execute; it has been removed. Recover it from version control if the
    # remote check is ever re-enabled.
    return await is_policy_violated__own_model(text)
|
2023-08-04 03:30:56 +02:00
|
|
|
|
2023-08-30 22:13:23 +02:00
|
|
|
async def is_policy_violated__own_model(inp: Union[str, list]) -> Union[bool, str]:
    """Check the input against the local `profanity_check` model.

    Args:
        inp: A string, a list of strings, or a list of message dicts with a
            'content' key. It is flattened with `input_to_text` first.

    Returns:
        The string 'own model detected' when the model flags the text,
        False otherwise.
    """

    text = input_to_text(inp)

    # predict() is given a one-item batch; its first result is the verdict for
    # our text (presumably one prediction per input — see profanity_check docs).
    # NOTE(review): this synchronous call runs on the event-loop thread inside
    # the coroutine; confirm its latency is acceptable under concurrent load.
    if profanity_check.predict([text])[0]:
        return 'own model detected'

    return False
|
|
|
|
|
2023-08-04 03:30:56 +02:00
|
|
|
if __name__ == '__main__':
    # Ad-hoc smoke run: classify a sample string and print the verdict.
    verdict = asyncio.run(is_policy_violated('kill ms'))
    print(verdict)
|