Compare commits

...

6 commits

Author SHA1 Message Date
monosans ade7244cea
Refactor file operations 2023-10-09 09:02:09 +03:00
nsde 003a7d3d71
Merge pull request #18 from monosans/patch-2
Fix dangling asyncio tasks
2023-10-09 00:06:45 +02:00
nsde ad9f442fa1
Merge pull request #19 from monosans/patch-3
Add missing await
2023-10-09 00:06:14 +02:00
nsde 23a904f3ce Added buffering, fixing a common chunk yielding issue 2023-10-08 23:56:32 +02:00
monosans de2710539f
Add missing await 2023-10-08 23:05:11 +03:00
monosans 007e078fb6
Fix dangling asyncio tasks 2023-10-08 23:05:03 +03:00
12 changed files with 70 additions and 39 deletions

View file

@ -1,6 +1,8 @@
import os import os
import json import json
import asyncio import asyncio
import aiofiles
import aiofiles.os
from sys import argv from sys import argv
from bson import json_util from bson import json_util
@ -18,8 +20,7 @@ async def main(output_dir: str):
async def make_backup(output_dir: str): async def make_backup(output_dir: str):
output_dir = os.path.join(FILE_DIR, '..', 'backups', output_dir) output_dir = os.path.join(FILE_DIR, '..', 'backups', output_dir)
if not os.path.exists(output_dir): await aiofiles.os.makedirs(output_dir, exist_ok=True)
os.makedirs(output_dir)
client = AsyncIOMotorClient(MONGO_URI) client = AsyncIOMotorClient(MONGO_URI)
databases = await client.list_database_names() databases = await client.list_database_names()
@ -29,22 +30,22 @@ async def make_backup(output_dir: str):
if database == 'local': if database == 'local':
continue continue
if not os.path.exists(f'{output_dir}/{database}'): await aiofiles.os.makedirs(os.path.join(output_dir, database), exist_ok=True)
os.mkdir(f'{output_dir}/{database}')
for collection in databases[database]: for collection in databases[database]:
print(f'Initiated database backup for {database}/{collection}') print(f'Initiated database backup for {database}/{collection}')
await make_backup_for_collection(database, collection, output_dir) await make_backup_for_collection(database, collection, output_dir)
async def make_backup_for_collection(database, collection, output_dir): async def make_backup_for_collection(database, collection, output_dir):
path = f'{output_dir}/{database}/{collection}.json' path = os.path.join(output_dir, database, f'{collection}.json')
client = AsyncIOMotorClient(MONGO_URI) client = AsyncIOMotorClient(MONGO_URI)
collection = client[database][collection] collection = client[database][collection]
documents = await collection.find({}).to_list(length=None) documents = await collection.find({}).to_list(length=None)
with open(path, 'w') as f: async with aiofiles.open(path, 'w') as f:
json.dump(documents, f, default=json_util.default) for chunk in json.JSONEncoder(default=json_util.default).iterencode(documents):
await f.write(chunk)
if __name__ == '__main__': if __name__ == '__main__':
if len(argv) < 2 or len(argv) > 2: if len(argv) < 2 or len(argv) > 2:

View file

@ -13,6 +13,7 @@ import json
import hmac import hmac
import httpx import httpx
import fastapi import fastapi
import aiofiles
import functools import functools
from dhooks import Webhook, Embed from dhooks import Webhook, Embed
@ -148,11 +149,14 @@ async def run_checks(incoming_request: fastapi.Request):
async def get_crypto_price(cryptocurrency: str) -> float: async def get_crypto_price(cryptocurrency: str) -> float:
"""Gets the price of a cryptocurrency using coinbase's API.""" """Gets the price of a cryptocurrency using coinbase's API."""
if os.path.exists('cache/crypto_prices.json'): cache_path = os.path.join('cache', 'crypto_prices.json')
with open('cache/crypto_prices.json', 'r') as f: try:
cache = json.load(f) async with aiofiles.open(cache_path) as f:
else: content = await f.read()
except FileNotFoundError:
cache = {} cache = {}
else:
cache = json.loads(content)
is_old = time.time() - cache.get('_last_updated', 0) > 60 * 60 is_old = time.time() - cache.get('_last_updated', 0) > 60 * 60
@ -164,8 +168,9 @@ async def get_crypto_price(cryptocurrency: str) -> float:
cache[cryptocurrency] = usd_price cache[cryptocurrency] = usd_price
cache['_last_updated'] = time.time() cache['_last_updated'] = time.time()
with open('cache/crypto_prices.json', 'w') as f: async with aiofiles.open(cache_path, 'w') as f:
json.dump(cache, f) for chunk in json.JSONEncoder().iterencode(cache):
await f.write(chunk)
return cache[cryptocurrency] return cache[cryptocurrency]

View file

@ -3,6 +3,8 @@ import time
import random import random
import asyncio import asyncio
import aiofiles
import aiofiles.os
from aiocache import cached from aiocache import cached
from dotenv import load_dotenv from dotenv import load_dotenv
from cachetools import TTLCache from cachetools import TTLCache
@ -72,10 +74,10 @@ class KeyManager:
db = await self._get_collection('providerkeys') db = await self._get_collection('providerkeys')
num = 0 num = 0
for filename in os.listdir('api/secret'): for filename in await aiofiles.os.listdir(os.path.join('api', 'secret')):
if filename.endswith('.txt'): if filename.endswith('.txt'):
with open(f'api/secret/{filename}') as f: async with aiofiles.open(os.path.join('api', 'secret', filename)) as f:
for line in f.readlines(): async for line in f:
if not line.strip(): if not line.strip():
continue continue

View file

@ -14,7 +14,7 @@ except ImportError:
load_dotenv() load_dotenv()
with open(helpers.root + '/api/config/config.yml', encoding='utf8') as f: with open(os.path.join(helpers.root, 'api', 'config', 'config.yml'), encoding='utf8') as f:
credits_config = yaml.safe_load(f) credits_config = yaml.safe_load(f)
## MONGODB Setup ## MONGODB Setup

View file

@ -19,10 +19,11 @@ from helpers import tokens, errors, network
load_dotenv() load_dotenv()
users = UserManager() users = UserManager()
models_list = json.load(open('cache/models.json', encoding='utf8')) with open(os.path.join('cache', 'models.json'), encoding='utf8') as f:
models_list = json.load(f)
models = [model['id'] for model in models_list['data']] models = [model['id'] for model in models_list['data']]
with open('config/config.yml', encoding='utf8') as f: with open(os.path.join('config', 'config.yml'), encoding='utf8') as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
moderation_debug_key_key = os.getenv('MODERATION_DEBUG_KEY') moderation_debug_key_key = os.getenv('MODERATION_DEBUG_KEY')

View file

@ -38,10 +38,10 @@ async def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') ->
tokens_per_name = -1 # if there's a name, the role is omitted tokens_per_name = -1 # if there's a name, the role is omitted
elif 'gpt-3.5-turbo' in model: elif 'gpt-3.5-turbo' in model:
return count_for_messages(messages, model='gpt-3.5-turbo-0613') return await count_for_messages(messages, model='gpt-3.5-turbo-0613')
elif 'gpt-4' in model: elif 'gpt-4' in model:
return count_for_messages(messages, model='gpt-4-0613') return await count_for_messages(messages, model='gpt-4-0613')
else: else:
raise NotImplementedError(f"""count_for_messages() is not implemented for model {model}. raise NotImplementedError(f"""count_for_messages() is not implemented for model {model}.

View file

@ -3,14 +3,13 @@ import sys
import aiohttp import aiohttp
import asyncio import asyncio
import importlib import importlib
import aiofiles.os
from rich import print from rich import print
def remove_duplicate_keys(file): def remove_duplicate_keys(file):
with open(file, 'r', encoding='utf8') as f: with open(file, 'r', encoding='utf8') as f:
lines = f.readlines() unique_lines = set(f)
unique_lines = set(lines)
with open(file, 'w', encoding='utf8') as f: with open(file, 'w', encoding='utf8') as f:
f.writelines(unique_lines) f.writelines(unique_lines)
@ -22,7 +21,7 @@ async def main():
except IndexError: except IndexError:
print('List of available providers:') print('List of available providers:')
for file_name in os.listdir(os.path.dirname(__file__)): for file_name in await aiofiles.os.listdir(os.path.dirname(__file__)):
if file_name.endswith('.py') and not file_name.startswith('_'): if file_name.endswith('.py') and not file_name.startswith('_'):
print(file_name.split('.')[0]) print(file_name.split('.')[0])

View file

@ -14,7 +14,7 @@ MODELS = [
] ]
# MODELS = [f'{model}-azure' for model in MODELS] # MODELS = [f'{model}-azure' for model in MODELS]
AZURE_API = '2023-07-01-preview' AZURE_API = '2023-08-01-preview'
async def chat_completion(**payload): async def chat_completion(**payload):
key = await utils.random_secret_for('azure-nva1') key = await utils.random_secret_for('azure-nva1')

View file

@ -96,7 +96,7 @@ proxies_in_files = []
for proxy_type in ['http', 'socks4', 'socks5']: for proxy_type in ['http', 'socks4', 'socks5']:
try: try:
with open(f'secret/proxies/{proxy_type}.txt') as f: with open(os.path.join('secret', 'proxies', f'{proxy_type}.txt')) as f:
for line in f: for line in f:
clean_line = line.split('#', 1)[0].strip() clean_line = line.split('#', 1)[0].strip()
if clean_line: if clean_line:

View file

@ -7,6 +7,7 @@ import aiohttp
import asyncio import asyncio
import starlette import starlette
from typing import Any, Coroutine, Set
from rich import print from rich import print
from dotenv import load_dotenv from dotenv import load_dotenv
@ -23,6 +24,19 @@ CRITICAL_API_ERRORS = ['invalid_api_key', 'account_deactivated']
keymanager = providerkeys.manager keymanager = providerkeys.manager
background_tasks: Set[asyncio.Task[Any]] = set()
def create_background_task(coro: Coroutine[Any, Any, Any]) -> None:
"""asyncio.create_task, which prevents the task from being garbage collected.
https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
"""
task = asyncio.create_task(coro)
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)
async def respond( async def respond(
path: str='/v1/chat/completions', path: str='/v1/chat/completions',
user: dict=None, user: dict=None,
@ -49,7 +63,7 @@ async def respond(
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
for i in range(1): for i in range(5):
try: try:
if is_chat: if is_chat:
target_request = await load_balancing.balance_chat_request(payload) target_request = await load_balancing.balance_chat_request(payload)
@ -96,7 +110,7 @@ async def respond(
cookies=target_request.get('cookies'), cookies=target_request.get('cookies'),
ssl=False, ssl=False,
timeout=aiohttp.ClientTimeout( timeout=aiohttp.ClientTimeout(
connect=1.0, connect=0.75,
total=float(os.getenv('TRANSFER_TIMEOUT', '500')) total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
) )
) as response: ) as response:
@ -148,24 +162,32 @@ async def respond(
continue continue
chunk_no = 0 chunk_no = 0
async for chunk in response.content.iter_any(): buffer = ''
async for chunk in response.content.iter_chunked(1024):
chunk_no += 1 chunk_no += 1
chunk = chunk.decode('utf8').strip()
chunk = chunk.decode('utf8')
if 'azure' in provider_name: if 'azure' in provider_name:
chunk = chunk.strip().replace('data: ', '') chunk = chunk.replace('data: ', '')
if not chunk or chunk_no == 1: if not chunk or chunk_no == 1:
continue continue
yield chunk + '\n\n' subchunks = chunk.split('\n\n')
buffer += subchunks[0]
yield buffer + '\n\n'
buffer = subchunks[-1]
for subchunk in subchunks[1:-1]:
yield subchunk + '\n\n'
break break
except Exception as exc: except aiohttp.client_exceptions.ServerTimeoutError:
print('[!] exception', exc) continue
# continue
raise exc
else: else:
yield await errors.yield_error(500, 'Sorry, our API seems to have issues connecting to our provider(s).', 'This most likely isn\'t your fault. Please try again later.') yield await errors.yield_error(500, 'Sorry, our API seems to have issues connecting to our provider(s).', 'This most likely isn\'t your fault. Please try again later.')
@ -174,7 +196,7 @@ async def respond(
if (not is_stream) and server_json_response: if (not is_stream) and server_json_response:
yield json.dumps(server_json_response) yield json.dumps(server_json_response)
asyncio.create_task( create_background_task(
after_request.after_request( after_request.after_request(
incoming_request=incoming_request, incoming_request=incoming_request,
target_request=target_request, target_request=target_request,

View file

@ -1,3 +1,4 @@
aiofiles==23.2.1
aiohttp==3.8.5 aiohttp==3.8.5
aiohttp_socks==0.8.0 aiohttp_socks==0.8.0
dhooks==1.1.4 dhooks==1.1.4

View file

@ -51,7 +51,7 @@ async def update_roles():
def launch(): def launch():
asyncio.run(main()) asyncio.run(main())
with open('rewards/last_update.txt', 'w', encoding='utf8') as f: with open(os.path.join('rewards', 'last_update.txt'), 'w', encoding='utf8') as f:
f.write(str(time.time())) f.write(str(time.time()))
if __name__ == '__main__': if __name__ == '__main__':