Compare commits

..

6 commits

Author SHA1 Message Date
Game_Time bb1e9de563 Update streaming to work & change stats to class 2023-08-13 21:42:38 +05:00
Game_Time 885da2a27e
add comments to streaming.py 2023-08-13 21:29:45 +05:00
Game_Time 3e811f3e3b
massive cleanup of streaming (i think this works?) 2023-08-13 21:26:35 +05:00
Game_Time def26f9104
clean up code a lot 2023-08-13 21:19:56 +05:00
Game_Time 8e70c25ee0
updating tests to add tests back since they were removed for some reason? 2023-08-13 20:16:33 +05:00
Game_Time 6ecc5f59ce Codebase changes + a lot of commenting 2023-08-13 20:12:35 +05:00
9 changed files with 175 additions and 190 deletions

View file

@ -24,11 +24,24 @@ async def _get_collection(collection_name: str):
return conn['nova-core'][collection_name] return conn['nova-core'][collection_name]
async def replacer(text: str, dict_: dict) -> str: async def replacer(text: str, dict_: dict) -> str:
# This seems to exist for a very specific and dumb purpose :D
for k, v in dict_.items(): for k, v in dict_.items():
text = text.replace(k, v) text = text.replace(k, v)
return text return text
async def log_api_request(user: dict, incoming_request, target_url: str): async def log_api_request(user: dict, incoming_request, target_url: str):
"""Logs the API Request into the database.
No input prompt is logged, however data such as IP & useragent is noted.
This would be useful for security reasons. Other minor data is also collected.
Args:
user (dict): User dict object
incoming_request (_type_): Request
target_url (str): The URL the api request was targetted to.
Returns:
_type_: _description_
"""
db = await _get_collection('logs') db = await _get_collection('logs')
payload = {} payload = {}

View file

@ -17,39 +17,50 @@ async def _get_collection(collection_name: str):
## Statistics ## Statistics
async def add_date(): class Stats:
"""
### The manager for all statistics tracking
Stats tracked:
- Dates
- IPs
- Target URLs
- Tokens
- Models
- URL Paths
"""
async def add_date():
date = datetime.datetime.now(pytz.timezone('GMT')).strftime('%Y.%m.%d') date = datetime.datetime.now(pytz.timezone('GMT')).strftime('%Y.%m.%d')
year, month, day = date.split('.') year, month, day = date.split('.')
db = await _get_collection('stats') db = await _get_collection('stats')
await db.update_one({}, {'$inc': {f'dates.{year}.{month}.{day}': 1}}, upsert=True) await db.update_one({}, {'$inc': {f'dates.{year}.{month}.{day}': 1}}, upsert=True)
async def add_ip_address(ip_address: str): async def add_ip_address(ip_address: str):
ip_address = ip_address.replace('.', '_') ip_address = ip_address.replace('.', '_')
db = await _get_collection('stats') db = await _get_collection('stats')
await db.update_one({}, {'$inc': {f'ips.{ip_address}': 1}}, upsert=True) await db.update_one({}, {'$inc': {f'ips.{ip_address}': 1}}, upsert=True)
async def add_target(url: str): async def add_target(url: str):
db = await _get_collection('stats') db = await _get_collection('stats')
await db.update_one({}, {'$inc': {f'targets.{url}': 1}}, upsert=True) await db.update_one({}, {'$inc': {f'targets.{url}': 1}}, upsert=True)
async def add_tokens(tokens: int, model: str): async def add_tokens(tokens: int, model: str):
db = await _get_collection('stats') db = await _get_collection('stats')
await db.update_one({}, {'$inc': {f'tokens.{model}': tokens}}, upsert=True) await db.update_one({}, {'$inc': {f'tokens.{model}': tokens}}, upsert=True)
async def add_model(model: str): async def add_model(model: str):
db = await _get_collection('stats') db = await _get_collection('stats')
await db.update_one({}, {'$inc': {f'models.{model}': 1}}, upsert=True) await db.update_one({}, {'$inc': {f'models.{model}': 1}}, upsert=True)
async def add_path(path: str): async def add_path(path: str):
path = path.replace('/', '_') path = path.replace('/', '_')
db = await _get_collection('stats') db = await _get_collection('stats')
await db.update_one({}, {'$inc': {f'paths.{path}': 1}}, upsert=True) await db.update_one({}, {'$inc': {f'paths.{path}': 1}}, upsert=True)
async def get_value(obj_filter): async def get_value(obj_filter):
db = await _get_collection('stats') db = await _get_collection('stats')
return await db.find_one({obj_filter}) return await db.find_one({obj_filter})
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(add_date()) asyncio.run(Stats.add_date())
asyncio.run(add_path('/__demo/test')) asyncio.run(Stats.add_path('/__demo/test'))

View file

@ -20,7 +20,14 @@ async def _get_collection(collection_name: str):
return conn['nova-core'][collection_name] return conn['nova-core'][collection_name]
async def create(discord_id: str='') -> dict: async def create(discord_id: str='') -> dict:
"""Adds a new user to the MongoDB collection.""" """Add a user to the mongodb
Args:
discord_id (str): Defaults to ''.
Returns:
dict: The user object
"""
chars = string.ascii_letters + string.digits chars = string.ascii_letters + string.digits

View file

@ -10,8 +10,10 @@ async def _get_module_name(module) -> str:
return name return name
async def balance_chat_request(payload: dict) -> dict: async def balance_chat_request(payload: dict) -> dict:
"""Load balance the chat completion request between chat providers. """
""" ### Load balance the chat completion request between chat providers.
Providers are sorted by streaming and models. Target (provider.chat_completion) is returned
"""
providers_available = [] providers_available = []
@ -36,9 +38,11 @@ async def balance_chat_request(payload: dict) -> dict:
return target return target
async def balance_organic_request(request: dict) -> dict: async def balance_organic_request(request: dict) -> dict:
"""Load balnace to non-chat completion request between other "organic" providers which respond in the desired format already. """
Organic providers are used for non-chat completions, such as moderation and other paths. ### Load balance non-chat completion request
""" Balances between other "organic" providers which respond in the desired format already.
Organic providers are used for non-chat completions, such as moderation and other paths.
"""
providers_available = [] providers_available = []
if not request.get('headers'): if not request.get('headers'):

View file

@ -31,7 +31,9 @@ async def startup_event():
@app.get('/') @app.get('/')
async def root(): async def root():
"""Returns the root endpoint.""" """
Returns the root endpoint.
"""
return { return {
'status': 'ok', 'status': 'ok',

View file

@ -63,7 +63,11 @@ class Proxy:
@property @property
def connector(self): def connector(self):
"""Returns an aiohttp_socks.ProxyConnector object. Which can be used in aiohttp.ClientSession.""" """
### Returns a proxy connector
Returns an aiohttp_socks.ProxyConnector object.
This can be used in aiohttp.ClientSession.
"""
proxy_types = { proxy_types = {
'http': aiohttp_socks.ProxyType.HTTP, 'http': aiohttp_socks.ProxyType.HTTP,

View file

@ -15,7 +15,8 @@ import proxies
import provider_auth import provider_auth
import load_balancing import load_balancing
from db import logs, users, stats from db import logs, users
from db.stats import Stats
from helpers import network, chat, errors from helpers import network, chat, errors
load_dotenv() load_dotenv()
@ -30,6 +31,33 @@ DEMO_PAYLOAD = {
] ]
} }
async def process_response(response, is_chat, chat_id, model, target_request):
"""Proccesses chunks from streaming
Args:
response (_type_): The response
is_chat (bool): If there is 'chat/completions' in path
chat_id (_type_): ID of chat with bot
model (_type_): What AI model it is
"""
async for chunk in response.content.iter_any():
chunk = chunk.decode("utf8").strip()
send = False
if is_chat and '{' in chunk:
data = json.loads(chunk.split('data: ')[1])
chunk = chunk.replace(data['id'], chat_id)
send = True
if target_request['module'] == 'twa' and data.get('text'):
chunk = await chat.create_chat_chunk(chat_id=chat_id, model=model, content=['text'])
if (not data['choices'][0]['delta']) or data['choices'][0]['delta'] == {'role': 'assistant'}:
send = False
if send and chunk:
yield chunk + '\n\n'
async def stream( async def stream(
path: str='/v1/chat/completions', path: str='/v1/chat/completions',
user: dict=None, user: dict=None,
@ -38,6 +66,17 @@ async def stream(
input_tokens: int=0, input_tokens: int=0,
incoming_request: starlette.requests.Request=None, incoming_request: starlette.requests.Request=None,
): ):
"""Stream the completions request. Sends data in chunks
Args:
path (str, optional): URL Path. Defaults to '/v1/chat/completions'.
user (dict, optional): User object (dict) Defaults to None.
payload (dict, optional): Payload. Defaults to None.
credits_cost (int, optional): Cost of the credits of the request. Defaults to 0.
input_tokens (int, optional): Total tokens calculated with tokenizer. Defaults to 0.
incoming_request (starlette.requests.Request, optional): Incoming request. Defaults to None.
"""
is_chat = False is_chat = False
is_stream = payload.get('stream', False) is_stream = payload.get('stream', False)
@ -45,34 +84,16 @@ async def stream(
is_chat = True is_chat = True
model = payload['model'] model = payload['model']
# Chat completions always have the same beginning
if is_chat and is_stream: if is_chat and is_stream:
chat_id = await chat.create_chat_id() chat_id = await chat.create_chat_id()
yield await chat.create_chat_chunk(chat_id=chat_id, model=model, content=chat.CompletionStart)
yield await chat.create_chat_chunk(chat_id=chat_id, model=model, content=None)
chunk = await chat.create_chat_chunk( json_response = {'error': 'No JSON response could be received'}
chat_id=chat_id,
model=model,
content=chat.CompletionStart
)
yield chunk
chunk = await chat.create_chat_chunk(
chat_id=chat_id,
model=model,
content=None
)
yield chunk
json_response = {
'error': 'No JSON response could be received'
}
# Try to get a response from the API
for _ in range(5): for _ in range(5):
headers = { headers = {'Content-Type': 'application/json'}
'Content-Type': 'application/json'
}
# Load balancing # Load balancing
# If the request is a chat completion, then we need to load balance between chat providers # If the request is a chat completion, then we need to load balance between chat providers
@ -82,7 +103,8 @@ async def stream(
if is_chat: if is_chat:
target_request = await load_balancing.balance_chat_request(payload) target_request = await load_balancing.balance_chat_request(payload)
else: else:
# "organic" means that it's not using a reverse engineered front-end, but rather ClosedAI's API directly
# In this case we are doing a organic request. "organic" means that it's not using a reverse engineered front-end, but rather ClosedAI's API directly
# churchless.tech is an example of an organic provider, because it redirects the request to ClosedAI. # churchless.tech is an example of an organic provider, because it redirects the request to ClosedAI.
target_request = await load_balancing.balance_organic_request({ target_request = await load_balancing.balance_organic_request({
'method': incoming_request.method, 'method': incoming_request.method,
@ -92,19 +114,12 @@ async def stream(
'cookies': incoming_request.cookies 'cookies': incoming_request.cookies
}) })
except ValueError as exc: except ValueError as exc:
# Error load balancing? Send a webhook to the admins
webhook = dhooks.Webhook(os.getenv('DISCORD_WEBHOOK__API_ISSUE')) webhook = dhooks.Webhook(os.getenv('DISCORD_WEBHOOK__API_ISSUE'))
webhook.send(content=f'API Issue: **`{exc}`**\nhttps://i.imgflip.com/7uv122.jpg') webhook.send(content=f'API Issue: **`{exc}`**\nhttps://i.imgflip.com/7uv122.jpg')
yield await errors.yield_error(500, 'Sorry, the API has no working keys anymore.', 'The admins have been messaged automatically.')
yield await errors.yield_error(
500,
'Sorry, the API has no working keys anymore.',
'The admins have been messaged automatically.'
)
return return
for k, v in target_request.get('headers', {}).items(): target_request['headers'].update(target_request.get('headers', {}))
target_request['headers'][k] = v
if target_request['method'] == 'GET' and not payload: if target_request['method'] == 'GET' and not payload:
target_request['payload'] = None target_request['payload'] = None
@ -116,22 +131,16 @@ async def stream(
async with session.request( async with session.request(
method=target_request.get('method', 'POST'), method=target_request.get('method', 'POST'),
url=target_request['url'], url=target_request['url'],
data=target_request.get('data'), data=target_request.get('data'),
json=target_request.get('payload'), json=target_request.get('payload'),
headers=target_request.get('headers', {}), headers=target_request.get('headers', {}),
cookies=target_request.get('cookies'), cookies=target_request.get('cookies'),
ssl=False, ssl=False,
timeout=aiohttp.ClientTimeout(total=float(os.getenv('TRANSFER_TIMEOUT', '120'))), timeout=aiohttp.ClientTimeout(total=float(os.getenv('TRANSFER_TIMEOUT', '120'))),
) as response: ) as response:
# if the answer is JSON
if response.content_type == 'application/json': if response.content_type == 'application/json':
data = await response.json() data = await response.json()
# Invalidate the key if it's not working
if data.get('code') == 'invalid_api_key': if data.get('code') == 'invalid_api_key':
await provider_auth.invalidate_key(target_request.get('provider_auth')) await provider_auth.invalidate_key(target_request.get('provider_auth'))
continue continue
@ -139,52 +148,15 @@ async def stream(
if response.ok: if response.ok:
json_response = data json_response = data
# if the answer is a stream
if is_stream: if is_stream:
try: try:
response.raise_for_status() response.raise_for_status()
except Exception as exc: except Exception as exc:
# Rate limit? Balance again
if 'Too Many Requests' in str(exc): if 'Too Many Requests' in str(exc):
continue continue
try: async for chunk in process_response(response, is_chat, chat_id, model, target_request):
# process the response chunks yield chunk
async for chunk in response.content.iter_any():
send = False
chunk = f'{chunk.decode("utf8")}\n\n'
if is_chat and '{' in chunk:
# parse the JSON
data = json.loads(chunk.split('data: ')[1])
chunk = chunk.replace(data['id'], chat_id)
send = True
# create a custom chunk if we're using specific providers
if target_request['module'] == 'twa' and data.get('text'):
chunk = await chat.create_chat_chunk(
chat_id=chat_id,
model=model,
content=['text']
)
# don't send empty/unnecessary messages
if (not data['choices'][0]['delta']) or data['choices'][0]['delta'] == {'role': 'assistant'}:
send = False
# send the chunk
if send and chunk.strip():
final_chunk = chunk.strip().replace('data: [DONE]', '') + '\n\n'
yield final_chunk
except Exception as exc:
if 'Connection closed' in str(exc):
yield await errors.yield_error(
500,
'Sorry, there was an issue with the connection.',
'Please first check if the issue on your end. If this error repeats, please don\'t heistate to contact the staff!.'
)
return
break break
@ -192,44 +164,27 @@ async def stream(
print('[!] Proxy error:', exc) print('[!] Proxy error:', exc)
continue continue
# Chat completions always have the same ending
if is_chat and is_stream: if is_chat and is_stream:
chunk = await chat.create_chat_chunk( yield await chat.create_chat_chunk(chat_id=chat_id, model=model, content=chat.CompletionStop)
chat_id=chat_id,
model=model,
content=chat.CompletionStop
)
yield chunk
yield 'data: [DONE]\n\n' yield 'data: [DONE]\n\n'
# If the response is JSON, then we need to yield it like this
if not is_stream and json_response: if not is_stream and json_response:
yield json.dumps(json_response) yield json.dumps(json_response)
# DONE WITH REQUEST, NOW LOGGING ETC.
if user and incoming_request: if user and incoming_request:
await logs.log_api_request( await logs.log_api_request(user=user, incoming_request=incoming_request, target_url=target_request['url'])
user=user,
incoming_request=incoming_request,
target_url=target_request['url']
)
if credits_cost and user: if credits_cost and user:
await users.update_by_id(user['_id'], { await users.update_by_id(user['_id'], {'$inc': {'credits': -credits_cost}})
'$inc': {'credits': -credits_cost}
})
ip_address = await network.get_ip(incoming_request) ip_address = await network.get_ip(incoming_request)
await Stats.add_date()
await stats.add_date() await Stats.add_ip_address(ip_address)
await stats.add_ip_address(ip_address) await Stats.add_path(path)
await stats.add_path(path) await Stats.add_target(target_request['url'])
await stats.add_target(target_request['url'])
if is_chat: if is_chat:
await stats.add_model(model) await Stats.add_model(model)
await stats.add_tokens(input_tokens, model) await Stats.add_tokens(input_tokens, model)
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(stream()) asyncio.run(stream())

View file

@ -20,87 +20,67 @@ with open('config/credits.yml', encoding='utf8') as f:
credits_config = yaml.safe_load(f) credits_config = yaml.safe_load(f)
async def handle(incoming_request): async def handle(incoming_request):
"""Transfer a streaming response from the incoming request to the target endpoint""" """
### Transfer a streaming response
Takes the request from the incoming request to the target endpoint.
Checks method, token amount, auth and cost along with if request is NSFW.
"""
path = incoming_request.url.path.replace('v1/v1/', 'v1/') path = incoming_request.url.path.replace('v1/v1/', 'v1/')
# METHOD allowed_methods = {'GET', 'POST', 'PUT', 'DELETE', 'PATCH'}
if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']: method = incoming_request.method
return await errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.')
if method not in allowed_methods:
return await errors.error(405, f'Method "{method}" is not allowed.', 'Change the request method to the correct one.')
# PAYLOAD
try:
payload = await incoming_request.json() payload = await incoming_request.json()
except json.decoder.JSONDecodeError:
payload = {}
# Tokenise w/ tiktoken
try: try:
input_tokens = await tokens.count_for_messages(payload['messages']) input_tokens = await tokens.count_for_messages(payload.get('messages', []))
except (KeyError, TypeError): except (KeyError, TypeError):
input_tokens = 0 input_tokens = 0
# Check user auth
received_key = incoming_request.headers.get('Authorization') received_key = incoming_request.headers.get('Authorization')
if not received_key: if not received_key or not received_key.startswith('Bearer '):
return await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.') return await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
if received_key.startswith('Bearer '): user = await users.by_api_key(received_key.split('Bearer ')[1].strip())
received_key = received_key.split('Bearer ')[1]
user = await users.by_api_key(received_key.strip()) if not user or not user['status']['active']:
return await errors.error(401, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.')
if not user:
return await errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.')
ban_reason = user['status']['ban_reason'] ban_reason = user['status']['ban_reason']
if ban_reason: if ban_reason:
return await errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.') return await errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.')
if not user['status']['active']: path_contains_models = '/models' in path
return await errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.') if path_contains_models:
if '/models' in path:
return fastapi.responses.JSONResponse(content=models_list) return fastapi.responses.JSONResponse(content=models_list)
# Calculate cost of tokens & check for nsfw prompts
costs = credits_config['costs'] costs = credits_config['costs']
cost = costs['other'] cost = costs['other']
policy_violation = False
if 'chat/completions' in path: if 'chat/completions' in path:
for model_name, model_cost in costs['chat-models'].items(): cost = costs['chat-models'].get(payload.get('model'), cost)
if model_name in payload['model']:
cost = model_cost
policy_violation = await moderation.is_policy_violated(payload['messages']) policy_violation = False
if 'chat/completions' in path or ('input' in payload or 'prompt' in payload):
elif '/moderations' in path: inp = payload.get('input', payload.get('prompt', ''))
pass if inp and len(inp) > 2 and not inp.isnumeric():
else:
inp = payload.get('input', payload.get('prompt'))
if inp:
if len(inp) > 2 and not inp.isnumeric():
policy_violation = await moderation.is_policy_violated(inp) policy_violation = await moderation.is_policy_violated(inp)
if policy_violation: if policy_violation:
return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.') return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1) role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)
cost = round(cost * role_cost_multiplier) cost = round(cost * role_cost_multiplier)
if user['credits'] < cost: if user['credits'] < cost:
return await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.') return await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.')
if 'chat/completions' in path and not payload.get('stream', False):
# Send the completion request
if 'chat/completions' in path and not payload.get('stream') is True:
payload['stream'] = False payload['stream'] = False
media_type = 'text/event-stream' if payload.get('stream', False) else 'application/json' media_type = 'text/event-stream' if payload.get('stream', False) else 'application/json'

View file

@ -83,11 +83,20 @@ def test_models():
def test_all(): def test_all():
"""Runs all tests.""" """Runs all tests."""
# print(test_server()) print("Running test on API server to check if its running.."
# print(test_api()) print(test_server())
print("Running a api endpoint to see if requests can go through..."
print(test_api())
print("Checking if the API works with the python library..."
print(test_library()) print(test_library())
# print(test_library_moderation())
# print(test_models()) print("Checking if the moderation endpoint works...")
print(test_library_moderation())
print("Checking if all models can be GET"
print(test_models())
def test_api_moderation(model: str=MODEL, messages: List[dict]=None) -> dict: def test_api_moderation(model: str=MODEL, messages: List[dict]=None) -> dict:
"""Tests an API api_endpoint.""" """Tests an API api_endpoint."""