diff --git a/api/moderation.py b/api/moderation.py index d3e8d99..c868af9 100644 --- a/api/moderation.py +++ b/api/moderation.py @@ -1,11 +1,14 @@ """This module contains functions for checking if a message violates the moderation policy.""" +import time import asyncio -import functools +import aiocache import profanity_check from typing import Union +cache = aiocache.Cache(aiocache.SimpleMemoryCache) + def input_to_text(inp: Union[str, list]) -> str: """Converts the input to a string.""" @@ -22,11 +25,18 @@ def input_to_text(inp: Union[str, list]) -> str: return text -@functools.lru_cache() async def is_policy_violated(inp: Union[str, list]) -> bool: """Checks if the input violates the moderation policy. """ - return await is_policy_violated__own_model(inp) + # use aio cache to cache the result + inp = input_to_text(inp) + + # utilize the cache + if await cache.exists(inp): + return await cache.get(inp) + else: + await cache.set(inp, await is_policy_violated__own_model(inp)) + return await cache.get(inp) async def is_policy_violated__own_model(inp: Union[str, list]) -> bool: """Checks if the input violates the moderation policy using our own model.""" @@ -39,4 +49,7 @@ async def is_policy_violated__own_model(inp: Union[str, list]) -> bool: return False if __name__ == '__main__': - print(asyncio.run(is_policy_violated('kill ms'))) + for i in range(10): + start = time.perf_counter() + print(asyncio.run(is_policy_violated('kill ms'))) + print((time.perf_counter() - start) * 1000)