Uses own AI for moderation

This commit is contained in:
nsde 2023-08-30 22:13:23 +02:00
parent d4237dd65e
commit 15f816fd1d
3 changed files with 40 additions and 11 deletions

View file

@ -3,6 +3,7 @@
import time
import asyncio
import aiohttp
import profanity_check
import proxies
import provider_auth
@ -10,14 +11,8 @@ import load_balancing
from typing import Union
async def is_policy_violated(inp: Union[str, list]) -> bool:
"""
### Check if a message violates the moderation policy.
You can either pass a list of messages consisting of dicts with "role" and "content", as used in the API parameter,
or just a simple string.
Returns True if the message violates the policy, False otherwise.
"""
def input_to_text(inp: Union[str, list]) -> str:
"""Converts the input to a string."""
text = inp
@ -30,7 +25,21 @@ async def is_policy_violated(inp: Union[str, list]) -> bool:
else:
text = '\n'.join(inp)
for _ in range(5):
return text
async def is_policy_violated(inp: Union[str, list]) -> bool:
"""
### Check if a message violates the moderation policy.
You can either pass a list of messages consisting of dicts with "role" and "content", as used in the API parameter,
or just a simple string.
Returns True if the message violates the policy, False otherwise.
"""
text = input_to_text(inp)
return await is_policy_violated__own_model(text)
for _ in range(1):
req = await load_balancing.balance_organic_request(
{
'path': '/v1/moderations',
@ -61,11 +70,18 @@ async def is_policy_violated(inp: Union[str, list]) -> bool:
return False
except Exception as exc:
if '401' in str(exc):
await provider_auth.invalidate_key(req.get('provider_auth'))
print('[!] moderation error:', type(exc), exc)
continue
async def is_policy_violated__own_model(inp: Union[str, list]) -> bool:
inp = input_to_text(inp)
if profanity_check.predict([inp])[0]:
return 'own model detected'
return False
if __name__ == '__main__':
print(asyncio.run(is_policy_violated('I wanna kill myself')))
print(asyncio.run(is_policy_violated('kill ms')))

View file

@ -177,5 +177,7 @@ async def stream(
model=model,
)
print(f'[+] {path} -> {model or "")
if __name__ == '__main__':
asyncio.run(stream())

View file

@ -23,6 +23,17 @@ or
pip install .
```
***
Profanity checking requires:
```
pip install alt-profanity-check
# doesn't work? try
pip install git+https://github.com/dimitrismistriotis/alt-profanity-check.git
```
## `.env` configuration
Create a `.env` file, make sure not to reveal any of its contents to anyone, and fill in the required values in the format `KEY=VALUE`. Otherwise, the code won't run.