diff --git a/api/config/config.yml b/api/config/config.yml index 84cc9c8..67a971a 100644 --- a/api/config/config.yml +++ b/api/config/config.yml @@ -8,23 +8,21 @@ costs: chat-models: gpt-4-32k: 100 gpt-4: 30 - gpt-3: 10 + gpt-3: 5 ## Roles Explanation # Bonuses: They are a multiplier for costs # They work like: final_cost = cost * bonus -# Rate limits: Limit the requests of the user -# Seconds to wait between requests roles: owner: - bonus: 0.1 + bonus: 0 admin: - bonus: 0.3 + bonus: 0.2 helper: bonus: 0.4 booster: - bonus: 0.5 + bonus: 0.6 default: bonus: 1.0 diff --git a/api/core.py b/api/core.py index e343e04..64c1ef3 100644 --- a/api/core.py +++ b/api/core.py @@ -3,14 +3,11 @@ import os import sys -from helpers import errors - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) sys.path.append(project_root) # the code above is to allow importing from the root folder -import os import json import hmac import fastapi @@ -20,6 +17,7 @@ from dotenv import load_dotenv import checks.client +from helpers import errors from db.users import UserManager load_dotenv() @@ -64,11 +62,13 @@ async def new_user_webhook(user: dict) -> None: color=0x90ee90, ) + dc = user['auth']['discord'] + embed.add_field(name='ID', value=str(user['_id']), inline=False) - embed.add_field(name='Discord', value=user['auth']['discord'] or '-') + embed.add_field(name='Discord', value=dc or '-') embed.add_field(name='Github', value=user['auth']['github'] or '-') - dhook.send(embed=embed) + dhook.send(content=f'<@{dc}>', embed=embed) @router.post('/users') async def create_user(incoming_request: fastapi.Request): diff --git a/api/handler.py b/api/handler.py index bcbe01b..b0f852d 100644 --- a/api/handler.py +++ b/api/handler.py @@ -1,5 +1,6 @@ """Does quite a few checks and prepares the incoming request for the target endpoint, so it can be streamed""" +import os import json import yaml import time @@ -23,6 +24,8 @@ models_list = json.load(open('models.json', encoding='utf8')) with open('config/config.yml', encoding='utf8') as f: config = yaml.safe_load(f) +moderation_debug_key_key = os.getenv('MODERATION_DEBUG_KEY') + async def handle(incoming_request: fastapi.Request): """ ### Transfer a streaming response @@ -47,7 +50,7 @@ async def handle(incoming_request: fastapi.Request): received_key = incoming_request.headers.get('Authorization') if not received_key or not received_key.startswith('Bearer '): - return await errors.error(403, 'No NovaAI API key given!', 'Add \'Authorization: Bearer nv-...\' to your request headers.') + return await errors.error(401, 'No NovaAI API key given!', 'Add \'Authorization: Bearer nv-...\' to your request headers.') key_tags = '' @@ -58,7 +61,7 @@ async def handle(incoming_request: fastapi.Request): user = await users.user_by_api_key(received_key.split('Bearer ')[1].strip()) if not user or not user['status']['active']: - return await errors.error(403, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.') + return await errors.error(418, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.') if user.get('auth', {}).get('discord'): print(f'[bold green]>Discord[/bold green] {user["auth"]["discord"]}') @@ -86,7 +89,7 @@ async def handle(incoming_request: fastapi.Request): return await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.') - if not 'DISABLE_VARS' in key_tags: + if 'DISABLE_VARS' not in key_tags: payload_with_vars = json.dumps(payload) replace_dict = { @@ -112,26 +115,30 @@ async def handle(incoming_request: fastapi.Request): payload = json.loads(payload_with_vars) policy_violation = False - if '/moderations' not in path: - inp = '' - if 'input' in payload or 'prompt' in payload: - inp = payload.get('input', payload.get('prompt', '')) + if not (moderation_debug_key_key and moderation_debug_key_key in key_tags and 'gpt-3' in payload.get('model', '')): + if '/moderations' not in path: + inp = '' - if isinstance(payload.get('messages'), list): - inp = '\n'.join([message['content'] for message in payload['messages']]) + if 'input' in payload or 'prompt' in payload: + inp = payload.get('input', payload.get('prompt', '')) - if inp and len(inp) > 2 and not inp.isnumeric(): - policy_violation = await moderation.is_policy_violated(inp) + if isinstance(payload.get('messages'), list): + inp = '\n'.join([message['content'] for message in payload['messages']]) + + if inp and len(inp) > 2 and not inp.isnumeric(): + policy_violation = await moderation.is_policy_violated(inp) if policy_violation: return await errors.error( - 400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', + 400, f'The request contains content which violates this model\'s policies for <{policy_violation}>.', 'We currently don\'t support any NSFW models.' ) if 'chat/completions' in path and not payload.get('stream', False): payload['stream'] = False + if 'chat/completions' in path and not payload.get('model'): + payload['model'] = 'gpt-3.5-turbo' media_type = 'text/event-stream' if payload.get('stream', False) else 'application/json' diff --git a/api/load_balancing.py b/api/load_balancing.py index bc781ea..0e18e02 100644 --- a/api/load_balancing.py +++ b/api/load_balancing.py @@ -27,11 +27,11 @@ async def balance_chat_request(payload: dict) -> dict: providers_available.append(provider_module) if not providers_available: - raise NotImplementedError(f'The model "{payload["model"]}" is not available. MODEl_UNAVAILABLE') + raise ValueError(f'The model "{payload["model"]}" is not available. MODEL_UNAVAILABLE') provider = random.choice(providers_available) - target = provider.chat_completion(**payload) - + target = await provider.chat_completion(**payload) + module_name = await _get_module_name(provider) target['module'] = module_name @@ -61,7 +61,7 @@ async def balance_organic_request(request: dict) -> dict: providers_available.append(provider_module) provider = random.choice(providers_available) - target = provider.organify(request) + target = await provider.organify(request) module_name = await _get_module_name(provider) target['module'] = module_name diff --git a/api/provider_auth.py b/api/provider_auth.py index 881fa8f..72d2cfd 100644 --- a/api/provider_auth.py +++ b/api/provider_auth.py @@ -1,7 +1,28 @@ """This module contains functions for authenticating with providers.""" +import os import asyncio +from dotenv import load_dotenv +from dhooks import Webhook, Embed + +load_dotenv() + +async def invalidation_webhook(provider_and_key: str) -> None: + """Runs when a new user is created.""" + + dhook = Webhook(os.environ['DISCORD_WEBHOOK__API_ISSUE']) + + embed = Embed( + description='Key Invalidated', + color=0xffee90, + ) + + embed.add_field(name='Provider', value=provider_and_key.split('>')[0]) + embed.add_field(name='Key (censored)', value=f'||{provider_and_key.split(">")[1][:10]}...||', inline=False) + + dhook.send(embed=embed) + async def invalidate_key(provider_and_key: str) -> None: """ @@ -28,5 +49,7 @@ async def invalidate_key(provider_and_key: str) -> None: with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f: f.write(key + '\n') + await invalidation_webhook(provider_and_key) + if __name__ == '__main__': - asyncio.run(invalidate_key('closed>cd...')) + asyncio.run(invalidate_key('closed>demo-...')) diff --git a/checks/client.py b/checks/client.py index 1b04de8..7c232e9 100644 --- a/checks/client.py +++ b/checks/client.py @@ -140,8 +140,8 @@ async def demo(): print('[lightblue]Checking if the API works...') print(await test_chat()) - print('[lightblue]Checking if SDXL image generation works...') - print(await test_sdxl()) + # print('[lightblue]Checking if SDXL image generation works...') + # print(await test_sdxl()) print('[lightblue]Checking if the moderation endpoint works...') print(await test_api_moderation()) diff --git a/setup.md b/setup.md index 1b60c80..f47009d 100644 --- a/setup.md +++ b/setup.md @@ -1,3 +1,13 @@ +# Setup +## Requirements +- Python 3.9+ +- pip +- MongoDB database + +## Recommended +- `git` (for updates) +- `screen` (for production) +- Cloudflare (for security, anti-DDoS, etc.) - we fully support Cloudflare ## Install Assuming you have a new version of Python 3.9+ and pip installed: @@ -110,6 +120,8 @@ You can also specify a port, e.g.: python run 1337 ``` +## Adding a provider + ## Test if it works `python checks`