diff --git a/api/load_balancing.py b/api/load_balancing.py index 8673d6c..3fb6f54 100644 --- a/api/load_balancing.py +++ b/api/load_balancing.py @@ -24,7 +24,7 @@ async def balance_chat_request(payload: dict) -> dict: providers_available.append(provider_module) if not providers_available: - raise NotImplementedError('This model does not exist.') + raise NotImplementedError(f'The model "{payload["model"]}" is not available. MODEL_UNAVAILABLE') provider = random.choice(providers_available) target = provider.chat_completion(**payload) diff --git a/api/moderation.py b/api/moderation.py index 38b2d8b..251dacb 100644 --- a/api/moderation.py +++ b/api/moderation.py @@ -5,7 +5,7 @@ import proxies import provider_auth import load_balancing -async def is_safe(inp) -> bool: +async def is_policy_violated(inp) -> bool: text = inp if isinstance(inp, list): @@ -35,17 +35,21 @@ async def is_safe(inp) -> bool: headers=req.get('headers'), cookies=req.get('cookies'), ssl=False, - timeout=aiohttp.ClientTimeout(total=5), + timeout=aiohttp.ClientTimeout(total=2), ) as res: res.raise_for_status() json_response = await res.json() + categories = json_response['results'][0]['category_scores'] - return not json_response['results'][0]['flagged'] + if json_response['results'][0]['flagged']: + return max(categories, key=categories.get) + + return False except Exception as exc: - await provider_auth.invalidate_key(req.get('provider_auth')) + # await provider_auth.invalidate_key(req.get('provider_auth')) print('[!] 
moderation error:', type(exc), exc) continue if __name__ == '__main__': - print(asyncio.run(is_safe('I wanna kill myself'))) + print(asyncio.run(is_policy_violated('I wanna kill myself'))) diff --git a/api/transfer.py b/api/transfer.py index 13f80ac..153d866 100644 --- a/api/transfer.py +++ b/api/transfer.py @@ -69,24 +69,27 @@ async def handle(incoming_request): costs = credits_config['costs'] cost = costs['other'] - is_safe = True + policy_violation = False if 'chat/completions' in path: for model_name, model_cost in costs['chat-models'].items(): if model_name in payload['model']: cost = model_cost - is_safe = await moderation.is_safe(payload['messages']) + policy_violation = await moderation.is_policy_violated(payload['messages']) + + elif '/moderations' in path: + pass else: inp = payload.get('input', payload.get('prompt')) if inp: if len(inp) > 2 and not inp.isnumeric(): - is_safe = await moderation.is_safe(inp) + policy_violation = await moderation.is_policy_violated(inp) - if not is_safe and not '/moderations' in path: - error = await errors.error(400, 'The request contains content which violates this model\'s policies.', 'We currently don\'t support any NSFW models.') + if policy_violation: + error = await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.') return error role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)