mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 14:43:58 +01:00
Fixed some several issues with moderation, models etc.
This commit is contained in:
parent
1a3e275a1c
commit
ce24c3a5a2
|
@ -24,7 +24,7 @@ async def balance_chat_request(payload: dict) -> dict:
|
|||
providers_available.append(provider_module)
|
||||
|
||||
if not providers_available:
|
||||
raise NotImplementedError('This model does not exist.')
|
||||
raise NotImplementedError(f'The model "{payload["model"]}" is not available. MODEl_UNAVAILABLE')
|
||||
|
||||
provider = random.choice(providers_available)
|
||||
target = provider.chat_completion(**payload)
|
||||
|
|
|
@ -5,7 +5,7 @@ import proxies
|
|||
import provider_auth
|
||||
import load_balancing
|
||||
|
||||
async def is_safe(inp) -> bool:
|
||||
async def is_policy_violated(inp) -> bool:
|
||||
text = inp
|
||||
|
||||
if isinstance(inp, list):
|
||||
|
@ -35,17 +35,21 @@ async def is_safe(inp) -> bool:
|
|||
headers=req.get('headers'),
|
||||
cookies=req.get('cookies'),
|
||||
ssl=False,
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
timeout=aiohttp.ClientTimeout(total=2),
|
||||
) as res:
|
||||
res.raise_for_status()
|
||||
json_response = await res.json()
|
||||
categories = json_response['results'][0]['category_scores']
|
||||
|
||||
return not json_response['results'][0]['flagged']
|
||||
if json_response['results'][0]['flagged']:
|
||||
return max(categories, key=categories.get)
|
||||
|
||||
return False
|
||||
|
||||
except Exception as exc:
|
||||
await provider_auth.invalidate_key(req.get('provider_auth'))
|
||||
# await provider_auth.invalidate_key(req.get('provider_auth'))
|
||||
print('[!] moderation error:', type(exc), exc)
|
||||
continue
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(asyncio.run(is_safe('I wanna kill myself')))
|
||||
print(asyncio.run(is_policy_violated('I wanna kill myself')))
|
||||
|
|
|
@ -69,24 +69,27 @@ async def handle(incoming_request):
|
|||
costs = credits_config['costs']
|
||||
cost = costs['other']
|
||||
|
||||
is_safe = True
|
||||
policy_violation = False
|
||||
|
||||
if 'chat/completions' in path:
|
||||
for model_name, model_cost in costs['chat-models'].items():
|
||||
if model_name in payload['model']:
|
||||
cost = model_cost
|
||||
|
||||
is_safe = await moderation.is_safe(payload['messages'])
|
||||
policy_violation = await moderation.is_policy_violated(payload['messages'])
|
||||
|
||||
elif '/moderations' in path:
|
||||
pass
|
||||
|
||||
else:
|
||||
inp = payload.get('input', payload.get('prompt'))
|
||||
|
||||
if inp:
|
||||
if len(inp) > 2 and not inp.isnumeric():
|
||||
is_safe = await moderation.is_safe(inp)
|
||||
policy_violation = await moderation.is_policy_violated(inp)
|
||||
|
||||
if not is_safe and not '/moderations' in path:
|
||||
error = await errors.error(400, 'The request contains content which violates this model\'s policies.', 'We currently don\'t support any NSFW models.')
|
||||
if policy_violation:
|
||||
error = await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
|
||||
return error
|
||||
|
||||
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)
|
||||
|
|
Loading…
Reference in a new issue