diff --git a/api/moderation.py b/api/moderation.py index bed4c94..90ff384 100644 --- a/api/moderation.py +++ b/api/moderation.py @@ -3,6 +3,7 @@ import time import asyncio import aiohttp +import profanity_check import proxies import provider_auth @@ -10,14 +11,8 @@ import load_balancing from typing import Union -async def is_policy_violated(inp: Union[str, list]) -> bool: - """ - ### Check if a message violates the moderation policy. - You can either pass a list of messages consisting of dicts with "role" and "content", as used in the API parameter, - or just a simple string. - - Returns True if the message violates the policy, False otherwise. - """ +def input_to_text(inp: Union[str, list]) -> str: + """Converts the input to a string.""" text = inp @@ -30,7 +25,21 @@ async def is_policy_violated(inp: Union[str, list]) -> bool: else: text = '\n'.join(inp) - for _ in range(5): + return text + +async def is_policy_violated(inp: Union[str, list]) -> bool: + """ + ### Check if a message violates the moderation policy. + You can either pass a list of messages consisting of dicts with "role" and "content", as used in the API parameter, + or just a simple string. + + Returns True if the message violates the policy, False otherwise. + """ + + text = input_to_text(inp) + return await is_policy_violated__own_model(text) + + for _ in range(1): req = await load_balancing.balance_organic_request( { 'path': '/v1/moderations', @@ -61,11 +70,18 @@ async def is_policy_violated(inp: Union[str, list]) -> bool: return False except Exception as exc: - if '401' in str(exc): await provider_auth.invalidate_key(req.get('provider_auth')) print('[!] moderation error:', type(exc), exc) continue +async def is_policy_violated__own_model(inp: Union[str, list]) -> bool: + inp = input_to_text(inp) + + if profanity_check.predict([inp])[0]: + return 'own model detected' + + return False + if __name__ == '__main__': - print(asyncio.run(is_policy_violated('I wanna kill myself'))) + print(asyncio.run(is_policy_violated('kill ms'))) diff --git a/api/streaming.py b/api/streaming.py index 4f4fa5a..f2689fb 100644 --- a/api/streaming.py +++ b/api/streaming.py @@ -177,5 +177,7 @@ async def stream( model=model, ) + print(f'[+] {path} -> {model or "") + if __name__ == '__main__': asyncio.run(stream()) diff --git a/setup.md b/setup.md index 1a3c2ed..1b60c80 100644 --- a/setup.md +++ b/setup.md @@ -23,6 +23,17 @@ or pip install . ``` +*** + +Profanity checking requires: + +``` +pip install alt-profanity-check +# doesn't work? try +pip install git+https://github.com/dimitrismistriotis/alt-profanity-check.git +``` + + ## `.env` configuration Create a `.env` file, make sure not to reveal any of its contents to anyone, and fill in the required values in the format `KEY=VALUE`. Otherwise, the code won't run.