2023-08-12 17:49:31 +02:00
|
|
|
"""This module contains functions for checking if a message violates the moderation policy."""
|
|
|
|
|
2023-08-04 03:30:56 +02:00
|
|
|
import asyncio
|
2023-09-02 22:09:57 +02:00
|
|
|
import functools
|
2023-08-30 22:13:23 +02:00
|
|
|
import profanity_check
|
2023-08-07 23:28:24 +02:00
|
|
|
|
2023-08-12 17:49:31 +02:00
|
|
|
from typing import Union
|
|
|
|
|
2023-08-30 22:13:23 +02:00
|
|
|
def input_to_text(inp: Union[str, list]) -> str:
    """Converts the input to a single string.

    Args:
        inp: A plain string (returned unchanged), a list of strings (joined
            with newlines), or a list of message dicts that each carry a
            ``'content'`` key (each content is followed by a newline).

    Returns:
        The combined text. An empty list yields an empty string.
    """
    if not isinstance(inp, list):
        return inp

    # An empty list has no first element to inspect; indexing inp[0]
    # would raise IndexError.
    if not inp:
        return ''

    # The first element decides the format of the whole list.
    if isinstance(inp[0], dict):
        # Every message content is newline-terminated, matching the
        # historical output of this function.
        return ''.join(msg['content'] + '\n' for msg in inp)

    return '\n'.join(inp)
|
|
|
|
|
2023-09-02 22:09:57 +02:00
|
|
|
async def is_policy_violated(inp: Union[str, list]) -> Union[str, bool]:
    """Checks if the input violates the moderation policy.

    Args:
        inp: A string, a list of strings, or a list of dicts with a
            ``'content'`` key.

    Returns:
        The reason string produced by ``is_policy_violated__own_model`` when a
        violation is detected, otherwise ``False``.
    """
    # Deliberately NOT wrapped in functools.lru_cache: on an ``async def`` the
    # cache stores the coroutine object itself, so a repeated call with the
    # same argument hands back an already-awaited coroutine and awaiting it
    # raises RuntimeError. lru_cache also needs hashable arguments, so list
    # inputs raised TypeError before the check could even run.
    return await is_policy_violated__own_model(inp)
|
2023-08-04 03:30:56 +02:00
|
|
|
|
2023-08-30 22:13:23 +02:00
|
|
|
async def is_policy_violated__own_model(inp: Union[str, list]) -> Union[str, bool]:
    """Checks if the input violates the moderation policy using our own model.

    Args:
        inp: A string, a list of strings, or a list of dicts with a
            ``'content'`` key; it is flattened with ``input_to_text``.

    Returns:
        A human-readable reason string if ``profanity_check`` flags the text,
        otherwise ``False``.
    """
    text = input_to_text(inp)

    # predict() is given a one-element batch, so element [0] is the verdict
    # for our text; a truthy value means it was flagged.
    if profanity_check.predict([text])[0]:
        return 'NovaAI\'s selfhosted moderation model detected unsuitable content.'

    return False
|
|
|
|
|
2023-08-04 03:30:56 +02:00
|
|
|
if __name__ == '__main__':
    # Manual smoke test: run one moderation check and show the verdict.
    verdict = asyncio.run(is_policy_violated('kill ms'))
    print(verdict)
|