Moderation caching

This commit is contained in:
nsde 2023-09-02 22:26:43 +02:00
parent 4256a5ca9d
commit 735d0e025b

View file

@ -1,11 +1,14 @@
"""This module contains functions for checking if a message violates the moderation policy.""" """This module contains functions for checking if a message violates the moderation policy."""
import time
import asyncio import asyncio
import functools import aiocache
import profanity_check import profanity_check
from typing import Union from typing import Union
cache = aiocache.Cache(aiocache.SimpleMemoryCache)
def input_to_text(inp: Union[str, list]) -> str: def input_to_text(inp: Union[str, list]) -> str:
"""Converts the input to a string.""" """Converts the input to a string."""
@ -22,11 +25,18 @@ def input_to_text(inp: Union[str, list]) -> str:
return text return text
@functools.lru_cache()
async def is_policy_violated(inp: Union[str, list]) -> bool: async def is_policy_violated(inp: Union[str, list]) -> bool:
"""Checks if the input violates the moderation policy. """Checks if the input violates the moderation policy.
""" """
return await is_policy_violated__own_model(inp) # use aio cache to cache the result
inp = input_to_text(inp)
# utilize the cache
if await cache.exists(inp):
return await cache.get(inp)
else:
await cache.set(inp, await is_policy_violated__own_model(inp))
return await cache.get(inp)
async def is_policy_violated__own_model(inp: Union[str, list]) -> bool: async def is_policy_violated__own_model(inp: Union[str, list]) -> bool:
"""Checks if the input violates the moderation policy using our own model.""" """Checks if the input violates the moderation policy using our own model."""
@ -39,4 +49,7 @@ async def is_policy_violated__own_model(inp: Union[str, list]) -> bool:
return False return False
if __name__ == '__main__': if __name__ == '__main__':
print(asyncio.run(is_policy_violated('kill ms'))) for i in range(10):
start = time.perf_counter()
print(asyncio.run(is_policy_violated('kill ms')))
print((time.perf_counter() - start) * 1000)