diff --git a/api/main.py b/api/main.py
index 5e9bf09..36c4e55 100644
--- a/api/main.py
+++ b/api/main.py
@@ -2,9 +2,11 @@
 
 import fastapi
 import pydantic
+import functools
 
 from rich import print
 from dotenv import load_dotenv
+from json import JSONDecodeError
 from bson.objectid import ObjectId
 from slowapi.errors import RateLimitExceeded
 from slowapi.middleware import SlowAPIMiddleware
@@ -15,6 +17,7 @@ from helpers import network
 
 import core
 import handler
+import moderation
 
 load_dotenv()
 
@@ -65,3 +68,15 @@ async def root():
 async def v1_handler(request: fastapi.Request):
     res = await handler.handle(request)
     return res
+
+@functools.lru_cache()
+@app.post('/moderate')
+async def moderate(request: fastapi.Request):
+    try:
+        prompt = await request.json()
+        prompt = prompt['text']
+    except (KeyError, JSONDecodeError):
+        return fastapi.Response(status_code=400)
+
+    result = await moderation.is_policy_violated__own_model(prompt)
+    return result or ''
diff --git a/api/moderation.py b/api/moderation.py
index 90ff384..d3e8d99 100644
--- a/api/moderation.py
+++ b/api/moderation.py
@@ -1,14 +1,9 @@
 """This module contains functions for checking if a message violates the moderation policy."""
 
-import time
 import asyncio
-import aiohttp
+import functools
 import profanity_check
 
-import proxies
-import provider_auth
-import load_balancing
-
 from typing import Union
 
 def input_to_text(inp: Union[str, list]) -> str:
@@ -27,59 +22,19 @@ def input_to_text(inp: Union[str, list]) -> str:
 
     return text
 
+@functools.lru_cache()
 async def is_policy_violated(inp: Union[str, list]) -> bool:
+    """Checks if the input violates the moderation policy.
     """
-    ### Check if a message violates the moderation policy.
-    You can either pass a list of messages consisting of dicts with "role" and "content", as used in the API parameter,
-    or just a simple string.
-
-    Returns True if the message violates the policy, False otherwise.
-    """
-
-    text = input_to_text(inp)
-    return await is_policy_violated__own_model(text)
-
-    for _ in range(1):
-        req = await load_balancing.balance_organic_request(
-            {
-                'path': '/v1/moderations',
-                'payload': {'input': text}
-            }
-        )
-
-        async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
-            try:
-                async with session.request(
-                    method=req.get('method', 'POST'),
-                    url=req['url'],
-                    data=req.get('data'),
-                    json=req.get('payload'),
-                    headers=req.get('headers'),
-                    cookies=req.get('cookies'),
-                    ssl=False,
-                    timeout=aiohttp.ClientTimeout(total=2),
-                ) as res:
-                    res.raise_for_status()
-                    json_response = await res.json()
-
-                    categories = json_response['results'][0]['category_scores']
-
-                    if json_response['results'][0]['flagged']:
-                        return max(categories, key=categories.get)
-
-                    return False
-
-            except Exception as exc:
-                if '401' in str(exc):
-                    await provider_auth.invalidate_key(req.get('provider_auth'))
-                print('[!] moderation error:', type(exc), exc)
-                continue
+    return await is_policy_violated__own_model(inp)
 
 async def is_policy_violated__own_model(inp: Union[str, list]) -> bool:
+    """Checks if the input violates the moderation policy using our own model."""
+
     inp = input_to_text(inp)
 
     if profanity_check.predict([inp])[0]:
-        return 'own model detected'
+        return 'NovaAI\'s selfhosted moderation model detected unsuitable content.'
 
     return False
 
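
Reviewer note (not part of the patch): a rough client-side sketch of exercising the new POST /moderate route added above. The base URL http://localhost:8000 and the use of the requests library are assumptions for illustration only; adjust them to your deployment. It simply posts {"text": ...} and inspects the plain-string response produced by is_policy_violated__own_model.

# hypothetical smoke test for the new /moderate endpoint -- base URL is an assumption
import requests

resp = requests.post('http://localhost:8000/moderate', json={'text': 'hello world'})

if resp.status_code == 400:
    # the handler returns 400 when the body is not JSON or lacks a "text" field
    print('bad request: send JSON like {"text": "..."}')
else:
    # the route returns '' for clean input, otherwise the moderation message string
    verdict = resp.json()
    print('flagged:' if verdict else 'clean', verdict)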