diff --git a/api/handler.py b/api/handler.py index b4911b6..168aa09 100644 --- a/api/handler.py +++ b/api/handler.py @@ -71,10 +71,10 @@ async def handle(incoming_request: fastapi.Request): if ban_reason: return await errors.error(403, f'Your NovaAI account has been banned. Reason: \'{ban_reason}\'.', 'Contact the staff for an appeal.') - # Checking for enterprise status - enterprise_keys = os.environ.get('ENTERPRISE_KEYS') - if path.startswith('/enterprise/v1') and user.get('api_key') not in enterprise_keys.split(): - return await errors.error(403, 'Enterprise API is not available.', 'Contact the staff for an upgrade.') + is_enterprise_key = 'enterprise' in user.get('role', 'default') + + if path.startswith('/enterprise/v1') and not is_enterprise_key: + return await errors.error(403, 'Enterprise API is not available for your API key.', 'Contact the staff for an upgrade.') if 'account/credits' in path: return fastapi.responses.JSONResponse({'credits': user['credits']}) @@ -87,10 +87,13 @@ async def handle(incoming_request: fastapi.Request): role = user.get('role', 'default') - try: - role_cost_multiplier = config['roles'][role]['bonus'] - except KeyError: - role_cost_multiplier = 1 + if 'enterprise' in role: + role_cost_multiplier = 0.1 + else: + try: + role_cost_multiplier = config['roles'][role]['bonus'] + except KeyError: + role_cost_multiplier = 1 cost = round(cost * role_cost_multiplier)