Fixed some several issues with moderation, models etc.

This commit is contained in:
nsde 2023-08-08 01:04:35 +02:00
parent 1a3e275a1c
commit ce24c3a5a2
3 changed files with 18 additions and 11 deletions

View file

@ -24,7 +24,7 @@ async def balance_chat_request(payload: dict) -> dict:
providers_available.append(provider_module)
if not providers_available:
raise NotImplementedError('This model does not exist.')
raise NotImplementedError(f'The model "{payload["model"]}" is not available. MODEl_UNAVAILABLE')
provider = random.choice(providers_available)
target = provider.chat_completion(**payload)

View file

@ -5,7 +5,7 @@ import proxies
import provider_auth
import load_balancing
async def is_safe(inp) -> bool:
async def is_policy_violated(inp) -> bool:
text = inp
if isinstance(inp, list):
@ -35,17 +35,21 @@ async def is_safe(inp) -> bool:
headers=req.get('headers'),
cookies=req.get('cookies'),
ssl=False,
timeout=aiohttp.ClientTimeout(total=5),
timeout=aiohttp.ClientTimeout(total=2),
) as res:
res.raise_for_status()
json_response = await res.json()
categories = json_response['results'][0]['category_scores']
return not json_response['results'][0]['flagged']
if json_response['results'][0]['flagged']:
return max(categories, key=categories.get)
return False
except Exception as exc:
await provider_auth.invalidate_key(req.get('provider_auth'))
# await provider_auth.invalidate_key(req.get('provider_auth'))
print('[!] moderation error:', type(exc), exc)
continue
if __name__ == '__main__':
print(asyncio.run(is_safe('I wanna kill myself')))
print(asyncio.run(is_policy_violated('I wanna kill myself')))

View file

@ -69,24 +69,27 @@ async def handle(incoming_request):
costs = credits_config['costs']
cost = costs['other']
is_safe = True
policy_violation = False
if 'chat/completions' in path:
for model_name, model_cost in costs['chat-models'].items():
if model_name in payload['model']:
cost = model_cost
is_safe = await moderation.is_safe(payload['messages'])
policy_violation = await moderation.is_policy_violated(payload['messages'])
elif '/moderations' in path:
pass
else:
inp = payload.get('input', payload.get('prompt'))
if inp:
if len(inp) > 2 and not inp.isnumeric():
is_safe = await moderation.is_safe(inp)
policy_violation = await moderation.is_policy_violated(inp)
if not is_safe and not '/moderations' in path:
error = await errors.error(400, 'The request contains content which violates this model\'s policies.', 'We currently don\'t support any NSFW models.')
if policy_violation:
error = await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
return error
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)