I don't even fkin know lmfao

This commit is contained in:
nsde 2023-08-18 21:23:00 +02:00
parent 0da45b0eeb
commit c07af0aed8
8 changed files with 77 additions and 102 deletions

30
api/chunks.py Normal file
View file

@ -0,0 +1,30 @@
import json
from helpers import chat
async def process_chunks(
    chunks,
    is_chat: bool,
    chat_id: int,
    target_request: dict,
    model: str = None,
):
    """Process raw streaming chunks from a provider and yield SSE lines.

    Args:
        chunks: Async iterable of raw ``bytes`` chunks from the provider.
        is_chat (bool): Whether the request is a chat-completions request.
        chat_id (int): ID injected in place of the provider's own chunk ID.
        target_request (dict): Provider request metadata; its ``module`` key
            selects provider-specific handling.
        model (str, optional): Model name forwarded to the chat chunk builder.

    Yields:
        str: An SSE-formatted chunk terminated by a blank line.
    """
    async for chunk in chunks:
        chunk = chunk.decode("utf8").strip()
        send = False

        if is_chat and '{' in chunk:
            data = json.loads(chunk.split('data: ')[1])
            # str.replace() requires str arguments; chat_id is declared int,
            # so it must be stringified before substitution.
            chunk = chunk.replace(data['id'], str(chat_id))
            send = True

            if target_request['module'] == 'twa' and data.get('text'):
                # Forward the provider's actual text; the original passed the
                # literal list ['text'] instead of the guarded data['text'].
                chunk = await chat.create_chat_chunk(chat_id=chat_id, model=model, content=data['text'])

            # Suppress empty deltas and the initial role-announcement delta.
            if (not data['choices'][0]['delta']) or data['choices'][0]['delta'] == {'role': 'assistant'}:
                send = False

        if send and chunk:
            yield chunk + '\n\n'

View file

@ -6,54 +6,25 @@ costs:
other: 10 other: 10
chat-models: chat-models:
gpt-3: 10
gpt-4: 30
gpt-4-32k: 100 gpt-4-32k: 100
gpt-4: 30
gpt-3: 10
## Roles Explanation ## Roles Explanation
# Bonuses: They are a multiplier for costs # Bonuses: They are a multiplier for costs
# They work like: final_cost = cost * bonus # They work like: final_cost = cost * bonus
# Rate limits: Limit the requests of the user # Rate limits: Limit the requests of the user
# The rate limit is by how many seconds until a new request can be done. # Seconds to wait between requests
## TODO: Setup proper rate limit settings for each role
## Current settings are:
## **NOT MEANT FOR PRODUCTION. DO NOT USE WITH THESE SETTINGS.**
roles: roles:
owner: owner:
bonus: 0.1 bonus: 0.1
rate_limit:
other: 60
gpt-3: 60
gpt-4: 35
gpt-4-32k: 5
admin: admin:
bonus: 0.3 bonus: 0.3
rate_limit:
other: 60
gpt-3: 60
gpt-4: 30
gpt-4-32k: 4
helper: helper:
bonus: 0.4 bonus: 0.4
rate_limit:
other: 60
gpt-3: 60
gpt-4: 25
gpt-4-32k: 3
booster: booster:
bonus: 0.5 bonus: 0.5
rate_limit:
other: 60
gpt-3: 60
gpt-4: 20
gpt-4-32k: 2
default: default:
bonus: 0 bonus: 1.0
rate_limit:
other: 60
gpt-3: 60
gpt-4: 15
gpt-4-32k: 1

View file

@ -71,3 +71,23 @@ async def create_user(incoming_request: fastapi.Request):
await new_user_webhook(user) await new_user_webhook(user)
return user return user
@router.put('/users')
async def update_user(incoming_request: fastapi.Request):
    """Update a user identified by their Discord ID.

    Expects a JSON payload of the form
    ``{"discord_id": ..., "updates": {...}}``; returns the database update
    result, or a 400 response when the payload is missing or malformed.
    """
    auth_error = await check_core_auth(incoming_request)

    if auth_error:
        return auth_error

    try:
        payload = await incoming_request.json()
        # Index directly instead of .get(): a missing key must be a 400,
        # not a silent update_by_discord_id(None, None) call.
        discord_id = payload['discord_id']
        updates = payload['updates']
    except (json.decoder.JSONDecodeError, AttributeError, KeyError):
        return fastapi.Response(status_code=400, content='Invalid or no payload received.')

    # Update the user
    manager = UserManager()
    user = await manager.update_by_discord_id(discord_id, updates)

    return user

View file

@ -83,6 +83,10 @@ class UserManager:
db = await self._get_collection('users') db = await self._get_collection('users')
return await db.update_one({'_id': user_id}, update) return await db.update_one({'_id': user_id}, update)
async def update_by_discord_id(self, discord_id: str, update):
    """Apply a MongoDB update to the user whose Discord ID matches.

    The name was previously misspelled ``upate_by_discord_id``, which broke
    the ``PUT /users`` route calling ``update_by_discord_id``.

    Args:
        discord_id (str): Discord ID; normalised via int() to reject
            non-numeric input before string comparison against the stored ID.
        update: MongoDB update document passed straight to ``update_one``.
    """
    db = await self._get_collection('users')
    return await db.update_one({'auth.discord': str(int(discord_id))}, update)

# Backwards-compatible alias for the original misspelled name.
upate_by_discord_id = update_by_discord_id
async def update_by_filter(self, obj_filter, update): async def update_by_filter(self, obj_filter, update):
db = await self._get_collection('users') db = await self._get_collection('users')
return await db.update_one(obj_filter, update) return await db.update_one(obj_filter, update)

View file

@ -8,7 +8,7 @@ async def invalidate_key(provider_and_key: str) -> None:
Invalidates a key stored in the secret/ folder by storing it in the associated .invalid.txt file. Invalidates a key stored in the secret/ folder by storing it in the associated .invalid.txt file.
The schema in which <provider_and_key> should be passed is: The schema in which <provider_and_key> should be passed is:
<provider_name><key>, e.g. <provider_name><key>, e.g.
closed4>sk-... closed4>cd-...
""" """
@ -29,4 +29,4 @@ async def invalidate_key(provider_and_key: str) -> None:
f.write(key + '\n') f.write(key + '\n')
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(invalidate_key('closed>sk-...')) asyncio.run(invalidate_key('closed>cd...'))

View file

@ -2,16 +2,17 @@
import os import os
import json import json
import yaml
import dhooks import dhooks
import asyncio import asyncio
import aiohttp import aiohttp
import starlette import starlette
import datetime
from rich import print from rich import print
from dotenv import load_dotenv from dotenv import load_dotenv
from python_socks._errors import ProxyError from python_socks._errors import ProxyError
import chunks
import proxies import proxies
import provider_auth import provider_auth
import load_balancing import load_balancing
@ -20,8 +21,6 @@ from db import logs
from db.users import UserManager from db.users import UserManager
from db.stats import StatsManager from db.stats import StatsManager
from helpers import network, chat, errors from helpers import network, chat, errors
import yaml
load_dotenv() load_dotenv()
@ -43,33 +42,6 @@ DEMO_PAYLOAD = {
] ]
} }
async def process_response(response, is_chat, chat_id, model, target_request):
"""Proccesses chunks from streaming
Args:
response (_type_): The response
is_chat (bool): If there is 'chat/completions' in path
chat_id (_type_): ID of chat with bot
model (_type_): What AI model it is
"""
async for chunk in response.content.iter_any():
chunk = chunk.decode("utf8").strip()
send = False
if is_chat and '{' in chunk:
data = json.loads(chunk.split('data: ')[1])
chunk = chunk.replace(data['id'], chat_id)
send = True
if target_request['module'] == 'twa' and data.get('text'):
chunk = await chat.create_chat_chunk(chat_id=chat_id, model=model, content=['text'])
if (not data['choices'][0]['delta']) or data['choices'][0]['delta'] == {'role': 'assistant'}:
send = False
if send and chunk:
yield chunk + '\n\n'
async def stream( async def stream(
path: str='/v1/chat/completions', path: str='/v1/chat/completions',
user: dict=None, user: dict=None,
@ -80,32 +52,8 @@ async def stream(
): ):
"""Stream the completions request. Sends data in chunks """Stream the completions request. Sends data in chunks
If not streaming, it sends the result in its entirety. If not streaming, it sends the result in its entirety.
Args:
path (str, optional): URL Path. Defaults to '/v1/chat/completions'.
user (dict, optional): User object (dict) Defaults to None.
payload (dict, optional): Payload. Defaults to None.
credits_cost (int, optional): Cost of the credits of the request. Defaults to 0.
input_tokens (int, optional): Total tokens calculated with tokenizer. Defaults to 0.
incoming_request (starlette.requests.Request, optional): Incoming request. Defaults to None.
""" """
## Rate limits user.
# If rate limit is exceeded, error code 429. Otherwise, lets the user pass but notes down
# last request time for future requests.
if user:
role = user.get('role', 'default')
rate_limit = config['roles'].get(role, 1)['rate_limit'].get(payload['model'], 1)
last_request_time = user_last_request_time.get(user['api_key'])
time_since_last_request = datetime.now() - last_request_time
if time_since_last_request < datetime.timedelta(seconds=rate_limit):
yield await errors.yield_error(429, "Rate limit exceeded', 'You are making requests too quickly. Please wait and try again later. Ask a administrator if you think this shouldn't happen. ")
return
else:
user_last_request_time[user['_id']] = datetime.now()
## Setup managers ## Setup managers
db = UserManager() db = UserManager()
stats = StatsManager() stats = StatsManager()
@ -127,11 +75,9 @@ async def stream(
for _ in range(5): for _ in range(5):
headers = {'Content-Type': 'application/json'} headers = {'Content-Type': 'application/json'}
# Load balancing: randomly selecting a suitable provider
# Load balancing
# If the request is a chat completion, then we need to load balance between chat providers # If the request is a chat completion, then we need to load balance between chat providers
# If the request is an organic request, then we need to load balance between organic providers # If the request is an organic request, then we need to load balance between organic providers
try: try:
if is_chat: if is_chat:
target_request = await load_balancing.balance_chat_request(payload) target_request = await load_balancing.balance_chat_request(payload)
@ -191,7 +137,13 @@ async def stream(
if 'Too Many Requests' in str(exc): if 'Too Many Requests' in str(exc):
continue continue
async for chunk in process_response(response, is_chat, chat_id, model, target_request): async for chunk in chunks.process_chunks(
chunks=response.content.iter_any(),
is_chat=is_chat,
chat_id=chat_id,
model=model,
target_request=target_request
):
yield chunk yield chunk
break break

View file

@ -28,12 +28,6 @@ async def handle(incoming_request):
users = UserManager() users = UserManager()
path = incoming_request.url.path.replace('v1/v1/', 'v1/') path = incoming_request.url.path.replace('v1/v1/', 'v1/')
allowed_methods = {'GET', 'POST', 'PUT', 'DELETE', 'PATCH'}
method = incoming_request.method
if method not in allowed_methods:
return await errors.error(405, f'Method "{method}" is not allowed.', 'Change the request method to the correct one.')
try: try:
payload = await incoming_request.json() payload = await incoming_request.json()
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
@ -78,7 +72,12 @@ async def handle(incoming_request):
return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.') return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
role = user.get('role', 'default') role = user.get('role', 'default')
role_cost_multiplier = config['roles'].get(role, 1)['bonus']
try:
role_cost_multiplier = config['roles'][role]['bonus']
except KeyError:
role_cost_multiplier = 1
cost = round(cost * role_cost_multiplier) cost = round(cost * role_cost_multiplier)
if user['credits'] < cost: if user['credits'] < cost:

View file

@ -63,7 +63,7 @@ def test_library():
def test_library_moderation(): def test_library_moderation():
try: try:
return closedai.Moderation.create('I wanna kill myself, I wanna kill myself; It\'s all I hear right now, it\'s all I hear right now') return closedai.Moderation.create('I wanna kill myself, I wanna kill myself; It\'s all I hear right now, it\'s all I hear right now')
except closedai.errors.InvalidRequestError as exc: except closedai.error.InvalidRequestError:
return True return True
def test_models(): def test_models():
@ -108,7 +108,6 @@ def test_all():
print(test_models()) print(test_models())
if __name__ == '__main__': if __name__ == '__main__':
api_endpoint = 'https://alpha-api.nova-oss.com/v1'
closedai.api_base = api_endpoint closedai.api_base = api_endpoint
closedai.api_key = os.getenv('TEST_NOVA_KEY') closedai.api_key = os.getenv('TEST_NOVA_KEY')