clean up code a lot

This commit is contained in:
Game_Time 2023-08-13 21:19:56 +05:00 committed by GitHub
parent 8e70c25ee0
commit def26f9104
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -25,86 +25,62 @@ async def handle(incoming_request):
Takes the request from the incoming request to the target endpoint. Takes the request from the incoming request to the target endpoint.
Checks method, token amount, auth and cost along with if request is NSFW. Checks method, token amount, auth and cost along with if request is NSFW.
""" """
path = incoming_request.url.path.replace('v1/v1/', 'v1/') path = incoming_request.url.path.replace('v1/v1/', 'v1/')
# METHOD allowed_methods = {'GET', 'POST', 'PUT', 'DELETE', 'PATCH'}
if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']: method = incoming_request.method
return await errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.')
# PAYLOAD if method not in allowed_methods:
try: return await errors.error(405, f'Method "{method}" is not allowed.', 'Change the request method to the correct one.')
payload = await incoming_request.json()
except json.decoder.JSONDecodeError: payload = await incoming_request.json()
payload = {}
# Tokenise w/ tiktoken
try: try:
input_tokens = await tokens.count_for_messages(payload['messages']) input_tokens = await tokens.count_for_messages(payload.get('messages', []))
except (KeyError, TypeError): except (KeyError, TypeError):
input_tokens = 0 input_tokens = 0
# Check user auth
received_key = incoming_request.headers.get('Authorization') received_key = incoming_request.headers.get('Authorization')
if not received_key: if not received_key or not received_key.startswith('Bearer '):
return await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.') return await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
if received_key.startswith('Bearer '): user = await users.by_api_key(received_key.split('Bearer ')[1].strip())
received_key = received_key.split('Bearer ')[1]
user = await users.by_api_key(received_key.strip()) if not user or not user['status']['active']:
return await errors.error(401, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.')
if not user:
return await errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.')
ban_reason = user['status']['ban_reason'] ban_reason = user['status']['ban_reason']
if ban_reason: if ban_reason:
return await errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.') return await errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.')
if not user['status']['active']: path_contains_models = '/models' in path
return await errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.') if path_contains_models:
if '/models' in path:
return fastapi.responses.JSONResponse(content=models_list) return fastapi.responses.JSONResponse(content=models_list)
# Calculate cost of tokens & check for nsfw prompts
costs = credits_config['costs'] costs = credits_config['costs']
cost = costs['other'] cost = costs['other']
policy_violation = False
if 'chat/completions' in path: if 'chat/completions' in path:
for model_name, model_cost in costs['chat-models'].items(): cost = costs['chat-models'].get(payload.get('model'), cost)
if model_name in payload['model']:
cost = model_cost
policy_violation = await moderation.is_policy_violated(payload['messages']) policy_violation = False
if 'chat/completions' in path or ('input' in payload or 'prompt' in payload):
elif '/moderations' in path: inp = payload.get('input', payload.get('prompt', ''))
pass if inp and len(inp) > 2 and not inp.isnumeric():
policy_violation = await moderation.is_policy_violated(inp)
else:
inp = payload.get('input', payload.get('prompt'))
if inp:
if len(inp) > 2 and not inp.isnumeric():
policy_violation = await moderation.is_policy_violated(inp)
if policy_violation: if policy_violation:
return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.') return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1) role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)
cost = round(cost * role_cost_multiplier) cost = round(cost * role_cost_multiplier)
if user['credits'] < cost: if user['credits'] < cost:
return await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.') return await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.')
if 'chat/completions' in path and not payload.get('stream', False):
# Send the completion request
if 'chat/completions' in path and not payload.get('stream') is True:
payload['stream'] = False payload['stream'] = False
media_type = 'text/event-stream' if payload.get('stream', False) else 'application/json' media_type = 'text/event-stream' if payload.get('stream', False) else 'application/json'