some general changes

This commit is contained in:
Game_Time 2023-08-12 20:56:21 +05:00
parent 23d463fad8
commit afc3b08351
6 changed files with 37 additions and 24 deletions

View file

@ -51,7 +51,7 @@ async def create_user(incoming_request: fastapi.Request):
auth_error = await check_core_auth(incoming_request) auth_error = await check_core_auth(incoming_request)
if auth_error: if auth_error:
return auth_error return auth_error
try: try:
payload = await incoming_request.json() payload = await incoming_request.json()

View file

@ -18,6 +18,16 @@ async def create_chat_id() -> str:
return f'chatcmpl-{chat_id}' return f'chatcmpl-{chat_id}'
async def create_chat_chunk(chat_id: str, model: str, content=None) -> dict: async def create_chat_chunk(chat_id: str, model: str, content=None) -> dict:
"""Creates the chunk for streaming chat.
Args:
chat_id (str): _description_
model (str): _description_
content (_type_, optional): _description_. Defaults to None.
Returns:
dict: _description_
"""
content = content or {} content = content or {}
delta = {} delta = {}

View file

@ -1,2 +0,0 @@
class Retry(Exception):
"""The server should retry the request."""

View file

@ -1,7 +1,18 @@
import tiktoken import tiktoken
async def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') -> int: async def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') -> int:
"""Return the number of tokens used by a list of messages.""" """Return the number of tokens used by a list of messages
Args:
messages (list): _description_
model (str, optional): _description_. Defaults to 'gpt-3.5-turbo-0613'.
Raises:
NotImplementedError: _description_
Returns:
int: _description_
"""
try: try:
encoding = tiktoken.encoding_for_model(model) encoding = tiktoken.encoding_for_model(model)

View file

@ -35,7 +35,7 @@ async def balance_chat_request(payload: dict) -> dict:
return target return target
async def balance_organic_request(request: dict) -> dict: async def balance_organic_request(request: dict) -> dict:
"""Load balnace to non-chat completion request between other "organic" providers which respond in the desired format already.""" """Load balance to non-chat completion request between other "organic" providers which respond in the desired format already."""
providers_available = [] providers_available = []

View file

@ -26,8 +26,7 @@ async def handle(incoming_request):
# METHOD # METHOD
if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']: if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']:
error = await errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.') return await errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.')
return error
# PAYLOAD # PAYLOAD
try: try:
@ -35,42 +34,37 @@ async def handle(incoming_request):
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
payload = {} payload = {}
# TOKENS # Tokenise w/ tiktoken
try: try:
input_tokens = await tokens.count_for_messages(payload['messages']) input_tokens = await tokens.count_for_messages(payload['messages'])
except (KeyError, TypeError): except (KeyError, TypeError):
input_tokens = 0 input_tokens = 0
# AUTH # Check user auth
received_key = incoming_request.headers.get('Authorization') received_key = incoming_request.headers.get('Authorization')
if not received_key: if not received_key:
error = await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.') return await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
return error
if received_key.startswith('Bearer '): if received_key.startswith('Bearer '):
received_key = received_key.split('Bearer ')[1] received_key = received_key.split('Bearer ')[1]
# USER
user = await users.by_api_key(received_key.strip()) user = await users.by_api_key(received_key.strip())
if not user: if not user:
error = await errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.') return await errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.')
return error
ban_reason = user['status']['ban_reason'] ban_reason = user['status']['ban_reason']
if ban_reason: if ban_reason:
error = await errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.') return await errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.')
return error
if not user['status']['active']: if not user['status']['active']:
error = await errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.') return await errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.')
return error
if '/models' in path: if '/models' in path:
return fastapi.responses.JSONResponse(content=models_list) return fastapi.responses.JSONResponse(content=models_list)
# COST # Calculate cost of tokens & check for nsfw prompts
costs = credits_config['costs'] costs = credits_config['costs']
cost = costs['other'] cost = costs['other']
@ -94,17 +88,17 @@ async def handle(incoming_request):
policy_violation = await moderation.is_policy_violated(inp) policy_violation = await moderation.is_policy_violated(inp)
if policy_violation: if policy_violation:
error = await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.') return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
return error
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1) role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)
cost = round(cost * role_cost_multiplier) cost = round(cost * role_cost_multiplier)
if user['credits'] < cost: if user['credits'] < cost:
error = await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.') return await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.')
return error
# READY
# Send the completion request
if 'chat/completions' in path and not payload.get('stream') is True: if 'chat/completions' in path and not payload.get('stream') is True:
payload['stream'] = False payload['stream'] = False