mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 19:23:58 +01:00
Fixed some several issues with moderation, models etc.
This commit is contained in:
parent
1a3e275a1c
commit
ce24c3a5a2
|
@ -24,7 +24,7 @@ async def balance_chat_request(payload: dict) -> dict:
|
||||||
providers_available.append(provider_module)
|
providers_available.append(provider_module)
|
||||||
|
|
||||||
if not providers_available:
|
if not providers_available:
|
||||||
raise NotImplementedError('This model does not exist.')
|
raise NotImplementedError(f'The model "{payload["model"]}" is not available. MODEl_UNAVAILABLE')
|
||||||
|
|
||||||
provider = random.choice(providers_available)
|
provider = random.choice(providers_available)
|
||||||
target = provider.chat_completion(**payload)
|
target = provider.chat_completion(**payload)
|
||||||
|
|
|
@ -5,7 +5,7 @@ import proxies
|
||||||
import provider_auth
|
import provider_auth
|
||||||
import load_balancing
|
import load_balancing
|
||||||
|
|
||||||
async def is_safe(inp) -> bool:
|
async def is_policy_violated(inp) -> bool:
|
||||||
text = inp
|
text = inp
|
||||||
|
|
||||||
if isinstance(inp, list):
|
if isinstance(inp, list):
|
||||||
|
@ -35,17 +35,21 @@ async def is_safe(inp) -> bool:
|
||||||
headers=req.get('headers'),
|
headers=req.get('headers'),
|
||||||
cookies=req.get('cookies'),
|
cookies=req.get('cookies'),
|
||||||
ssl=False,
|
ssl=False,
|
||||||
timeout=aiohttp.ClientTimeout(total=5),
|
timeout=aiohttp.ClientTimeout(total=2),
|
||||||
) as res:
|
) as res:
|
||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
json_response = await res.json()
|
json_response = await res.json()
|
||||||
|
categories = json_response['results'][0]['category_scores']
|
||||||
|
|
||||||
return not json_response['results'][0]['flagged']
|
if json_response['results'][0]['flagged']:
|
||||||
|
return max(categories, key=categories.get)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
await provider_auth.invalidate_key(req.get('provider_auth'))
|
# await provider_auth.invalidate_key(req.get('provider_auth'))
|
||||||
print('[!] moderation error:', type(exc), exc)
|
print('[!] moderation error:', type(exc), exc)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
print(asyncio.run(is_safe('I wanna kill myself')))
|
print(asyncio.run(is_policy_violated('I wanna kill myself')))
|
||||||
|
|
|
@ -69,24 +69,27 @@ async def handle(incoming_request):
|
||||||
costs = credits_config['costs']
|
costs = credits_config['costs']
|
||||||
cost = costs['other']
|
cost = costs['other']
|
||||||
|
|
||||||
is_safe = True
|
policy_violation = False
|
||||||
|
|
||||||
if 'chat/completions' in path:
|
if 'chat/completions' in path:
|
||||||
for model_name, model_cost in costs['chat-models'].items():
|
for model_name, model_cost in costs['chat-models'].items():
|
||||||
if model_name in payload['model']:
|
if model_name in payload['model']:
|
||||||
cost = model_cost
|
cost = model_cost
|
||||||
|
|
||||||
is_safe = await moderation.is_safe(payload['messages'])
|
policy_violation = await moderation.is_policy_violated(payload['messages'])
|
||||||
|
|
||||||
|
elif '/moderations' in path:
|
||||||
|
pass
|
||||||
|
|
||||||
else:
|
else:
|
||||||
inp = payload.get('input', payload.get('prompt'))
|
inp = payload.get('input', payload.get('prompt'))
|
||||||
|
|
||||||
if inp:
|
if inp:
|
||||||
if len(inp) > 2 and not inp.isnumeric():
|
if len(inp) > 2 and not inp.isnumeric():
|
||||||
is_safe = await moderation.is_safe(inp)
|
policy_violation = await moderation.is_policy_violated(inp)
|
||||||
|
|
||||||
if not is_safe and not '/moderations' in path:
|
if policy_violation:
|
||||||
error = await errors.error(400, 'The request contains content which violates this model\'s policies.', 'We currently don\'t support any NSFW models.')
|
error = await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
|
||||||
return error
|
return error
|
||||||
|
|
||||||
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)
|
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)
|
||||||
|
|
Loading…
Reference in a new issue