Fixed some several issues with moderation, models etc.

This commit is contained in:
nsde 2023-08-08 01:04:35 +02:00
parent 1a3e275a1c
commit ce24c3a5a2
3 changed files with 18 additions and 11 deletions

View file

@ -24,7 +24,7 @@ async def balance_chat_request(payload: dict) -> dict:
providers_available.append(provider_module) providers_available.append(provider_module)
if not providers_available: if not providers_available:
raise NotImplementedError('This model does not exist.') raise NotImplementedError(f'The model "{payload["model"]}" is not available. MODEl_UNAVAILABLE')
provider = random.choice(providers_available) provider = random.choice(providers_available)
target = provider.chat_completion(**payload) target = provider.chat_completion(**payload)

View file

@ -5,7 +5,7 @@ import proxies
import provider_auth import provider_auth
import load_balancing import load_balancing
async def is_safe(inp) -> bool: async def is_policy_violated(inp) -> bool:
text = inp text = inp
if isinstance(inp, list): if isinstance(inp, list):
@ -35,17 +35,21 @@ async def is_safe(inp) -> bool:
headers=req.get('headers'), headers=req.get('headers'),
cookies=req.get('cookies'), cookies=req.get('cookies'),
ssl=False, ssl=False,
timeout=aiohttp.ClientTimeout(total=5), timeout=aiohttp.ClientTimeout(total=2),
) as res: ) as res:
res.raise_for_status() res.raise_for_status()
json_response = await res.json() json_response = await res.json()
categories = json_response['results'][0]['category_scores']
return not json_response['results'][0]['flagged'] if json_response['results'][0]['flagged']:
return max(categories, key=categories.get)
return False
except Exception as exc: except Exception as exc:
await provider_auth.invalidate_key(req.get('provider_auth')) # await provider_auth.invalidate_key(req.get('provider_auth'))
print('[!] moderation error:', type(exc), exc) print('[!] moderation error:', type(exc), exc)
continue continue
if __name__ == '__main__': if __name__ == '__main__':
print(asyncio.run(is_safe('I wanna kill myself'))) print(asyncio.run(is_policy_violated('I wanna kill myself')))

View file

@ -69,24 +69,27 @@ async def handle(incoming_request):
costs = credits_config['costs'] costs = credits_config['costs']
cost = costs['other'] cost = costs['other']
is_safe = True policy_violation = False
if 'chat/completions' in path: if 'chat/completions' in path:
for model_name, model_cost in costs['chat-models'].items(): for model_name, model_cost in costs['chat-models'].items():
if model_name in payload['model']: if model_name in payload['model']:
cost = model_cost cost = model_cost
is_safe = await moderation.is_safe(payload['messages']) policy_violation = await moderation.is_policy_violated(payload['messages'])
elif '/moderations' in path:
pass
else: else:
inp = payload.get('input', payload.get('prompt')) inp = payload.get('input', payload.get('prompt'))
if inp: if inp:
if len(inp) > 2 and not inp.isnumeric(): if len(inp) > 2 and not inp.isnumeric():
is_safe = await moderation.is_safe(inp) policy_violation = await moderation.is_policy_violated(inp)
if not is_safe and not '/moderations' in path: if policy_violation:
error = await errors.error(400, 'The request contains content which violates this model\'s policies.', 'We currently don\'t support any NSFW models.') error = await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
return error return error
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1) role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)