From def26f91046113130abb6bd4c49817cf2d1c2b9c Mon Sep 17 00:00:00 2001 From: Game_Time <108236317+RayBytes@users.noreply.github.com> Date: Sun, 13 Aug 2023 21:19:56 +0500 Subject: [PATCH] clean up code a lot --- api/transfer.py | 66 ++++++++++++++++--------------------------------- 1 file changed, 21 insertions(+), 45 deletions(-) diff --git a/api/transfer.py b/api/transfer.py index 157c68a..85e83a9 100644 --- a/api/transfer.py +++ b/api/transfer.py @@ -25,86 +25,62 @@ async def handle(incoming_request): Takes the request from the incoming request to the target endpoint. Checks method, token amount, auth and cost along with if request is NSFW. """ - + path = incoming_request.url.path.replace('v1/v1/', 'v1/') - # METHOD - if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']: - return await errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.') + allowed_methods = {'GET', 'POST', 'PUT', 'DELETE', 'PATCH'} + method = incoming_request.method - # PAYLOAD - try: - payload = await incoming_request.json() - except json.decoder.JSONDecodeError: - payload = {} + if method not in allowed_methods: + return await errors.error(405, f'Method "{method}" is not allowed.', 'Change the request method to the correct one.') + + payload = await incoming_request.json() - # Tokenise w/ tiktoken try: - input_tokens = await tokens.count_for_messages(payload['messages']) + input_tokens = await tokens.count_for_messages(payload.get('messages', [])) except (KeyError, TypeError): input_tokens = 0 - # Check user auth received_key = incoming_request.headers.get('Authorization') - if not received_key: + if not received_key or not received_key.startswith('Bearer '): return await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.') - if received_key.startswith('Bearer '): - received_key = received_key.split('Bearer ')[1] + user = await users.by_api_key(received_key.split('Bearer ')[1].strip()) - user = await users.by_api_key(received_key.strip()) - - if not user: - return await errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.') + if not user or not user['status']['active']: + return await errors.error(401, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.') ban_reason = user['status']['ban_reason'] if ban_reason: return await errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.') - if not user['status']['active']: - return await errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.') - - if '/models' in path: + path_contains_models = '/models' in path + if path_contains_models: return fastapi.responses.JSONResponse(content=models_list) - # Calculate cost of tokens & check for nsfw prompts costs = credits_config['costs'] cost = costs['other'] - policy_violation = False - if 'chat/completions' in path: - for model_name, model_cost in costs['chat-models'].items(): - if model_name in payload['model']: - cost = model_cost + cost = costs['chat-models'].get(payload.get('model'), cost) - policy_violation = await moderation.is_policy_violated(payload['messages']) - - elif '/moderations' in path: - pass - - else: - inp = payload.get('input', payload.get('prompt')) - - if inp: - if len(inp) > 2 and not inp.isnumeric(): - policy_violation = await moderation.is_policy_violated(inp) + policy_violation = False + if 'chat/completions' in path or ('input' in payload or 'prompt' in payload): + inp = payload.get('input', payload.get('prompt', '')) + if inp and len(inp) > 2 and not inp.isnumeric(): + policy_violation = await moderation.is_policy_violated(inp) if policy_violation: return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.') - role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1) cost = round(cost * role_cost_multiplier) if user['credits'] < cost: return await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.') - - # Send the completion request - - if 'chat/completions' in path and not payload.get('stream') is True: + if 'chat/completions' in path and not payload.get('stream', False): payload['stream'] = False media_type = 'text/event-stream' if payload.get('stream', False) else 'application/json'