nova-api/api/moderation.py

56 lines
1.4 KiB
Python
Raw Normal View History

2023-08-12 17:49:31 +02:00
"""This module contains functions for checking if a message violates the moderation policy."""
2023-09-02 22:26:43 +02:00
import time
2023-09-10 16:22:46 +02:00
import difflib
2023-08-04 03:30:56 +02:00
import asyncio
2023-09-02 22:26:43 +02:00
import aiocache
2023-08-30 22:13:23 +02:00
import profanity_check
2023-08-12 17:49:31 +02:00
from typing import Union
2023-09-10 16:22:46 +02:00
from Levenshtein import distance
2023-08-12 17:49:31 +02:00
2023-09-02 22:26:43 +02:00
# Shared aiocache instance backed by SimpleMemoryCache; is_policy_violated()
# stores moderation verdicts in it, keyed by the flattened input text.
cache = aiocache.Cache(aiocache.SimpleMemoryCache)
2023-08-30 22:13:23 +02:00
def input_to_text(inp: Union[str, list]) -> str:
    """Flatten a moderation input into a single string.

    Args:
        inp: A plain string, a list of strings, or a list of message dicts
            that carry their text under the ``'content'`` key. For a list,
            the type of the first element decides how the whole list is
            interpreted.

    Returns:
        ``inp`` unchanged when it is not a list; the items joined with
        newlines for a list of strings; each message's content followed by
        a newline (so the result ends with a newline) for a list of dicts;
        and ``''`` for an empty list.

    Raises:
        KeyError: If a message dict has no ``'content'`` key.
    """
    if not isinstance(inp, list):
        return inp

    # An empty list has no first element to inspect; without this guard the
    # ``inp[0]`` lookup below would raise IndexError.
    if not inp:
        return ''

    if isinstance(inp[0], dict):
        # Every message contributes "<content>\n", which keeps the trailing
        # newline that existing callers (and cache keys) already see.
        return ''.join(msg['content'] + '\n' for msg in inp)

    return '\n'.join(inp)
async def is_policy_violated(inp: Union[str, list]) -> Union[bool, str]:
    """Check whether the input violates the moderation policy, with caching.

    The input is flattened with ``input_to_text`` and the resulting text is
    used as the cache key, so repeated inputs are not re-evaluated.

    Args:
        inp: A string, a list of strings, or a list of message dicts.

    Returns:
        Whatever ``is_policy_violated__own_model`` returns for the text:
        ``False`` when nothing was flagged, otherwise a message string
        describing the violation.
    """
    text = input_to_text(inp)

    # One lookup instead of exists() followed by get(). Verdicts are either
    # False or a message string, never None, so None can only mean a miss.
    cached = await cache.get(text)
    if cached is not None:
        return cached

    verdict = await is_policy_violated__own_model(text)

    # NOTE(review): no ttl is passed here, so the cache presumably grows with
    # every distinct input -- confirm and consider bounding it.
    await cache.set(text, verdict)

    # Return the computed verdict directly rather than reading it back from
    # the cache, which would cost another round-trip.
    return verdict
2023-08-04 03:30:56 +02:00
2023-08-30 22:13:23 +02:00
async def is_policy_violated__own_model(inp: Union[str, list]) -> Union[bool, str]:
    """Check the input against the moderation policy using profanity_check.

    Args:
        inp: A string, a list of strings, or a list of message dicts; it is
            flattened with ``input_to_text`` and lower-cased before being
            passed to the classifier.

    Returns:
        A user-facing message string when ``profanity_check.predict`` flags
        the text, otherwise ``False``.
    """
    text = input_to_text(inp).lower()

    # predict() takes a list of texts and returns one label per text; only a
    # single text is submitted, so only the first label is relevant.
    if profanity_check.predict([text])[0]:
        return 'Sorry, our moderation AI has detected NSFW content in your message.'

    return False
2023-08-04 03:30:56 +02:00
if __name__ == '__main__':
    # Interactive smoke test: read a line, print the moderation verdict,
    # repeat until the user ends the session.
    while True:
        try:
            line = input('-> ')
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt exits cleanly instead of dumping
            # a traceback.
            print()
            break
        print(asyncio.run(is_policy_violated(line)))