diff --git a/api/chunks.py b/api/chunks.py new file mode 100644 index 0000000..7ad2752 --- /dev/null +++ b/api/chunks.py @@ -0,0 +1,30 @@ +import json + +from helpers import chat + +async def process_chunks( + chunks, + is_chat: bool, + chat_id: int, + target_request: dict, + model: str=None, +): + """This function processes the response chunks from the providers and yields them. + """ + async for chunk in chunks: + chunk = chunk.decode("utf8").strip() + send = False + + if is_chat and '{' in chunk: + data = json.loads(chunk.split('data: ')[1]) + chunk = chunk.replace(data['id'], chat_id) + send = True + + if target_request['module'] == 'twa' and data.get('text'): + chunk = await chat.create_chat_chunk(chat_id=chat_id, model=model, content=['text']) + + if (not data['choices'][0]['delta']) or data['choices'][0]['delta'] == {'role': 'assistant'}: + send = False + + if send and chunk: + yield chunk + '\n\n' diff --git a/api/config/config.yml b/api/config/config.yml index 593392c..84cc9c8 100644 --- a/api/config/config.yml +++ b/api/config/config.yml @@ -6,54 +6,25 @@ costs: other: 10 chat-models: - gpt-3: 10 - gpt-4: 30 gpt-4-32k: 100 + gpt-4: 30 + gpt-3: 10 ## Roles Explanation # Bonuses: They are a multiplier for costs # They work like: final_cost = cost * bonus # Rate limits: Limit the requests of the user -# The rate limit is by how many seconds until a new request can be done. - -## TODO: Setup proper rate limit settings for each role -## Current settings are: -## **NOT MEANT FOR PRODUCTION. 
DO NOT USE WITH THESE SETTINGS.** +# Seconds to wait between requests roles: owner: bonus: 0.1 - rate_limit: - other: 60 - gpt-3: 60 - gpt-4: 35 - gpt-4-32k: 5 admin: bonus: 0.3 - rate_limit: - other: 60 - gpt-3: 60 - gpt-4: 30 - gpt-4-32k: 4 helper: bonus: 0.4 - rate_limit: - other: 60 - gpt-3: 60 - gpt-4: 25 - gpt-4-32k: 3 booster: bonus: 0.5 - rate_limit: - other: 60 - gpt-3: 60 - gpt-4: 20 - gpt-4-32k: 2 default: - bonus: 0 - rate_limit: - other: 60 - gpt-3: 60 - gpt-4: 15 - gpt-4-32k: 1 \ No newline at end of file + bonus: 1.0 diff --git a/api/core.py b/api/core.py index 25f5e1d..8294c44 100644 --- a/api/core.py +++ b/api/core.py @@ -71,3 +71,23 @@ async def create_user(incoming_request: fastapi.Request): await new_user_webhook(user) return user + +@router.put('/users') +async def update_user(incoming_request: fastapi.Request): + auth_error = await check_core_auth(incoming_request) + + if auth_error: + return auth_error + + try: + payload = await incoming_request.json() + discord_id = payload.get('discord_id') + updates = payload.get('updates') + except (json.decoder.JSONDecodeError, AttributeError): + return fastapi.Response(status_code=400, content='Invalid or no payload received.') + + # Update the user + manager = UserManager() + user = await manager.update_by_discord_id(discord_id, updates) + + return user diff --git a/api/db/users.py b/api/db/users.py index 63073d5..f6958f8 100644 --- a/api/db/users.py +++ b/api/db/users.py @@ -83,6 +83,10 @@ class UserManager: db = await self._get_collection('users') return await db.update_one({'_id': user_id}, update) + async def update_by_discord_id(self, discord_id: str, update): + db = await self._get_collection('users') + return await db.update_one({'auth.discord': str(int(discord_id))}, update) + async def update_by_filter(self, obj_filter, update): db = await self._get_collection('users') return await db.update_one(obj_filter, update) diff --git a/api/provider_auth.py b/api/provider_auth.py index a6aaed3..881fa8f
100644 --- a/api/provider_auth.py +++ b/api/provider_auth.py @@ -8,7 +8,7 @@ async def invalidate_key(provider_and_key: str) -> None: Invalidates a key stored in the secret/ folder by storing it in the associated .invalid.txt file. The schmea in which should be passed is: , e.g. - closed4>sk-... + closed4>cd-... """ @@ -29,4 +29,4 @@ async def invalidate_key(provider_and_key: str) -> None: f.write(key + '\n') if __name__ == '__main__': - asyncio.run(invalidate_key('closed>sk-...')) + asyncio.run(invalidate_key('closed>cd...')) diff --git a/api/streaming.py b/api/streaming.py index 4ff7ecc..819c89b 100644 --- a/api/streaming.py +++ b/api/streaming.py @@ -2,16 +2,17 @@ import os import json +import yaml import dhooks import asyncio import aiohttp import starlette -import datetime from rich import print from dotenv import load_dotenv from python_socks._errors import ProxyError +import chunks import proxies import provider_auth import load_balancing @@ -20,8 +21,6 @@ from db import logs from db.users import UserManager from db.stats import StatsManager from helpers import network, chat, errors -import yaml - load_dotenv() @@ -43,33 +42,6 @@ DEMO_PAYLOAD = { ] } -async def process_response(response, is_chat, chat_id, model, target_request): - """Proccesses chunks from streaming - - Args: - response (_type_): The response - is_chat (bool): If there is 'chat/completions' in path - chat_id (_type_): ID of chat with bot - model (_type_): What AI model it is - """ - async for chunk in response.content.iter_any(): - chunk = chunk.decode("utf8").strip() - send = False - - if is_chat and '{' in chunk: - data = json.loads(chunk.split('data: ')[1]) - chunk = chunk.replace(data['id'], chat_id) - send = True - - if target_request['module'] == 'twa' and data.get('text'): - chunk = await chat.create_chat_chunk(chat_id=chat_id, model=model, content=['text']) - - if (not data['choices'][0]['delta']) or data['choices'][0]['delta'] == {'role': 'assistant'}: - send = False - - if send and 
chunk: - yield chunk + '\n\n' - async def stream( path: str='/v1/chat/completions', user: dict=None, @@ -80,32 +52,8 @@ async def stream( ): """Stream the completions request. Sends data in chunks If not streaming, it sends the result in its entirety. - - Args: - path (str, optional): URL Path. Defaults to '/v1/chat/completions'. - user (dict, optional): User object (dict) Defaults to None. - payload (dict, optional): Payload. Defaults to None. - credits_cost (int, optional): Cost of the credits of the request. Defaults to 0. - input_tokens (int, optional): Total tokens calculated with tokenizer. Defaults to 0. - incoming_request (starlette.requests.Request, optional): Incoming request. Defaults to None. """ - ## Rate limits user. - # If rate limit is exceeded, error code 429. Otherwise, lets the user pass but notes down - # last request time for future requests. - if user: - role = user.get('role', 'default') - rate_limit = config['roles'].get(role, 1)['rate_limit'].get(payload['model'], 1) - - last_request_time = user_last_request_time.get(user['api_key']) - time_since_last_request = datetime.now() - last_request_time - - if time_since_last_request < datetime.timedelta(seconds=rate_limit): - yield await errors.yield_error(429, "Rate limit exceeded', 'You are making requests too quickly. Please wait and try again later. Ask a administrator if you think this shouldn't happen. 
") - return - else: - user_last_request_time[user['_id']] = datetime.now() - ## Setup managers db = UserManager() stats = StatsManager() @@ -127,11 +75,9 @@ async def stream( for _ in range(5): headers = {'Content-Type': 'application/json'} - - # Load balancing + # Load balancing: randomly selecting a suitable provider # If the request is a chat completion, then we need to load balance between chat providers # If the request is an organic request, then we need to load balance between organic providers - try: if is_chat: target_request = await load_balancing.balance_chat_request(payload) @@ -191,7 +137,13 @@ async def stream( if 'Too Many Requests' in str(exc): continue - async for chunk in process_response(response, is_chat, chat_id, model, target_request): + async for chunk in chunks.process_chunks( + chunks=response.content.iter_any(), + is_chat=is_chat, + chat_id=chat_id, + model=model, + target_request=target_request + ): yield chunk break diff --git a/api/transfer.py b/api/transfer.py index 98c3093..9d4c73d 100644 --- a/api/transfer.py +++ b/api/transfer.py @@ -28,12 +28,6 @@ async def handle(incoming_request): users = UserManager() path = incoming_request.url.path.replace('v1/v1/', 'v1/') - allowed_methods = {'GET', 'POST', 'PUT', 'DELETE', 'PATCH'} - method = incoming_request.method - - if method not in allowed_methods: - return await errors.error(405, f'Method "{method}" is not allowed.', 'Change the request method to the correct one.') - try: payload = await incoming_request.json() except json.decoder.JSONDecodeError: @@ -78,7 +72,12 @@ async def handle(incoming_request): return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.') role = user.get('role', 'default') - role_cost_multiplier = config['roles'].get(role, 1)['bonus'] + + try: + role_cost_multiplier = config['roles'][role]['bonus'] + except KeyError: + role_cost_multiplier = 1 + cost 
= round(cost * role_cost_multiplier) if user['credits'] < cost: diff --git a/tests/__main__.py b/tests/__main__.py index 9bdf8f0..5c6e91a 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -63,7 +63,7 @@ def test_library(): def test_library_moderation(): try: return closedai.Moderation.create('I wanna kill myself, I wanna kill myself; It\'s all I hear right now, it\'s all I hear right now') - except closedai.errors.InvalidRequestError as exc: + except closedai.error.InvalidRequestError: return True def test_models(): @@ -108,7 +108,6 @@ def test_all(): print(test_models()) if __name__ == '__main__': - api_endpoint = 'https://alpha-api.nova-oss.com/v1' closedai.api_base = api_endpoint closedai.api_key = os.getenv('TEST_NOVA_KEY')