From 007e078fb6b04b5abb9b669cd533a69fe31f84d7 Mon Sep 17 00:00:00 2001
From: monosans
Date: Fri, 6 Oct 2023 10:45:50 +0300
Subject: [PATCH 1/3] Fix dangling asyncio tasks

---
 api/responder.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/api/responder.py b/api/responder.py
index e515c84..35b7e6e 100644
--- a/api/responder.py
+++ b/api/responder.py
@@ -7,6 +7,7 @@ import aiohttp
 import asyncio
 import starlette
 
+from typing import Any, Coroutine, Set
 from rich import print
 from dotenv import load_dotenv
 
@@ -23,6 +24,19 @@ CRITICAL_API_ERRORS = ['invalid_api_key', 'account_deactivated']
 
 keymanager = providerkeys.manager
 
+background_tasks: Set[asyncio.Task[Any]] = set()
+
+
+def create_background_task(coro: Coroutine[Any, Any, Any]) -> None:
+    """asyncio.create_task, which prevents the task from being garbage collected.
+
+    https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
+    """
+    task = asyncio.create_task(coro)
+    background_tasks.add(task)
+    task.add_done_callback(background_tasks.discard)
+
+
 async def respond(
     path: str='/v1/chat/completions',
     user: dict=None,
@@ -174,7 +188,7 @@ async def respond(
     if (not is_stream) and server_json_response:
         yield json.dumps(server_json_response)
 
-    asyncio.create_task(
+    create_background_task(
         after_request.after_request(
             incoming_request=incoming_request,
             target_request=target_request,

From de2710539f9ed73b7e0101054a3e29d5e378c29b Mon Sep 17 00:00:00 2001
From: monosans
Date: Sun, 8 Oct 2023 10:59:52 +0300
Subject: [PATCH 2/3] Add missing await

---
 api/helpers/tokens.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/api/helpers/tokens.py b/api/helpers/tokens.py
index 848cac8..86fe04b 100644
--- a/api/helpers/tokens.py
+++ b/api/helpers/tokens.py
@@ -38,10 +38,10 @@ async def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') ->
         tokens_per_name = -1  # if there's a name, the role is omitted
 
     elif 'gpt-3.5-turbo' in model:
-        return count_for_messages(messages, model='gpt-3.5-turbo-0613')
+        return await count_for_messages(messages, model='gpt-3.5-turbo-0613')
 
     elif 'gpt-4' in model:
-        return count_for_messages(messages, model='gpt-4-0613')
+        return await count_for_messages(messages, model='gpt-4-0613')
 
     else:
         raise NotImplementedError(f"""count_for_messages() is not implemented for model {model}.
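
Note on PATCH 1/3: as the asyncio documentation linked in the docstring explains, the event loop holds only a weak reference to tasks, so a task created with a bare asyncio.create_task() whose result is discarded can be garbage-collected before it finishes. Below is a minimal, self-contained sketch of the same keep-a-reference pattern; log_request() and the sleep timings are illustrative stand-ins, not code from this repository.

import asyncio
from typing import Any, Coroutine, Set

background_tasks: Set[asyncio.Task[Any]] = set()

def create_background_task(coro: Coroutine[Any, Any, Any]) -> None:
    """Schedule coro and hold a strong reference to the task until it completes."""
    task = asyncio.create_task(coro)
    background_tasks.add(task)                         # keep the task alive
    task.add_done_callback(background_tasks.discard)   # release it once finished

async def log_request(n: int) -> None:
    # Illustrative stand-in for the real after_request coroutine.
    await asyncio.sleep(0.1)
    print(f'logged request {n}')

async def main() -> None:
    for n in range(3):
        create_background_task(log_request(n))  # instead of a bare asyncio.create_task()
    await asyncio.sleep(0.5)  # give the background tasks time to finish

if __name__ == '__main__':
    asyncio.run(main())
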
From ade7244ceae1986b3d084360b6873ab7138d6585 Mon Sep 17 00:00:00 2001
From: monosans
Date: Fri, 6 Oct 2023 10:37:16 +0300
Subject: [PATCH 3/3] Refactor file operations

---
 api/backup_manager/main.py | 15 ++++++++-------
 api/core.py                | 17 +++++++++++------
 api/db/providerkeys.py     |  8 +++++---
 api/db/users.py            |  2 +-
 api/handler.py             |  5 +++--
 api/providers/__main__.py  |  7 +++----
 api/proxies.py             |  2 +-
 requirements.txt           |  1 +
 rewards/main.py            |  2 +-
 9 files changed, 34 insertions(+), 25 deletions(-)

diff --git a/api/backup_manager/main.py b/api/backup_manager/main.py
index d7db852..a55f89c 100644
--- a/api/backup_manager/main.py
+++ b/api/backup_manager/main.py
@@ -1,6 +1,8 @@
 import os
 import json
 import asyncio
+import aiofiles
+import aiofiles.os
 
 from sys import argv
 from bson import json_util
@@ -18,8 +20,7 @@ async def main(output_dir: str):
 
 async def make_backup(output_dir: str):
     output_dir = os.path.join(FILE_DIR, '..', 'backups', output_dir)
-    if not os.path.exists(output_dir):
-        os.makedirs(output_dir)
+    await aiofiles.os.makedirs(output_dir, exist_ok=True)
 
     client = AsyncIOMotorClient(MONGO_URI)
     databases = await client.list_database_names()
@@ -29,22 +30,22 @@ async def make_backup(output_dir: str):
         if database == 'local':
             continue
 
-        if not os.path.exists(f'{output_dir}/{database}'):
-            os.mkdir(f'{output_dir}/{database}')
+        await aiofiles.os.makedirs(os.path.join(output_dir, database), exist_ok=True)
 
         for collection in databases[database]:
            print(f'Initiated database backup for {database}/{collection}')
            await make_backup_for_collection(database, collection, output_dir)
 
 async def make_backup_for_collection(database, collection, output_dir):
-    path = f'{output_dir}/{database}/{collection}.json'
+    path = os.path.join(output_dir, database, f'{collection}.json')
 
     client = AsyncIOMotorClient(MONGO_URI)
     collection = client[database][collection]
     documents = await collection.find({}).to_list(length=None)
 
-    with open(path, 'w') as f:
-        json.dump(documents, f, default=json_util.default)
+    async with aiofiles.open(path, 'w') as f:
+        for chunk in json.JSONEncoder(default=json_util.default).iterencode(documents):
+            await f.write(chunk)
 
 if __name__ == '__main__':
     if len(argv) < 2 or len(argv) > 2:
diff --git a/api/core.py b/api/core.py
index f802e20..059b122 100644
--- a/api/core.py
+++ b/api/core.py
@@ -13,6 +13,7 @@ import json
 import hmac
 import httpx
 import fastapi
+import aiofiles
 import functools
 
 from dhooks import Webhook, Embed
@@ -148,11 +149,14 @@ async def run_checks(incoming_request: fastapi.Request):
 
 async def get_crypto_price(cryptocurrency: str) -> float:
     """Gets the price of a cryptocurrency using coinbase's API."""
 
-    if os.path.exists('cache/crypto_prices.json'):
-        with open('cache/crypto_prices.json', 'r') as f:
-            cache = json.load(f)
-    else:
+    cache_path = os.path.join('cache', 'crypto_prices.json')
+    try:
+        async with aiofiles.open(cache_path) as f:
+            content = await f.read()
+    except FileNotFoundError:
         cache = {}
+    else:
+        cache = json.loads(content)
 
     is_old = time.time() - cache.get('_last_updated', 0) > 60 * 60
@@ -164,8 +168,9 @@ async def get_crypto_price(cryptocurrency: str) -> float:
 
         cache[cryptocurrency] = usd_price
         cache['_last_updated'] = time.time()
 
-        with open('cache/crypto_prices.json', 'w') as f:
-            json.dump(cache, f)
+        async with aiofiles.open(cache_path, 'w') as f:
+            for chunk in json.JSONEncoder().iterencode(cache):
+                await f.write(chunk)
 
     return cache[cryptocurrency]
diff --git a/api/db/providerkeys.py b/api/db/providerkeys.py
index 5365ba7..b202a69 100644
--- a/api/db/providerkeys.py
+++ b/api/db/providerkeys.py
@@ -3,6 +3,8 @@ import time
 import random
 import asyncio
+import aiofiles
+import aiofiles.os
 
 from aiocache import cached
 from dotenv import load_dotenv
 from cachetools import TTLCache
@@ -72,10 +74,10 @@ class KeyManager:
         db = await self._get_collection('providerkeys')
         num = 0
 
-        for filename in os.listdir('api/secret'):
+        for filename in await aiofiles.os.listdir(os.path.join('api', 'secret')):
             if filename.endswith('.txt'):
-                with open(f'api/secret/{filename}') as f:
-                    for line in f.readlines():
+                async with aiofiles.open(os.path.join('api', 'secret', filename)) as f:
+                    async for line in f:
                         if not line.strip():
                             continue
 
diff --git a/api/db/users.py b/api/db/users.py
index 2e325d4..3328c35 100644
--- a/api/db/users.py
+++ b/api/db/users.py
@@ -14,7 +14,7 @@ except ImportError:
 
 load_dotenv()
 
-with open(helpers.root + '/api/config/config.yml', encoding='utf8') as f:
+with open(os.path.join(helpers.root, 'api', 'config', 'config.yml'), encoding='utf8') as f:
     credits_config = yaml.safe_load(f)
 
 ## MONGODB Setup
diff --git a/api/handler.py b/api/handler.py
index 891ed80..36270d6 100644
--- a/api/handler.py
+++ b/api/handler.py
@@ -19,10 +19,11 @@ from helpers import tokens, errors, network
 load_dotenv()
 
 users = UserManager()
-models_list = json.load(open('cache/models.json', encoding='utf8'))
+with open(os.path.join('cache', 'models.json'), encoding='utf8') as f:
+    models_list = json.load(f)
 models = [model['id'] for model in models_list['data']]
 
-with open('config/config.yml', encoding='utf8') as f:
+with open(os.path.join('config', 'config.yml'), encoding='utf8') as f:
     config = yaml.safe_load(f)
 
 moderation_debug_key_key = os.getenv('MODERATION_DEBUG_KEY')
diff --git a/api/providers/__main__.py b/api/providers/__main__.py
index 0fdd158..2301ec8 100644
--- a/api/providers/__main__.py
+++ b/api/providers/__main__.py
@@ -3,14 +3,13 @@ import sys
 import aiohttp
 import asyncio
 import importlib
+import aiofiles.os
 
 from rich import print
 
 def remove_duplicate_keys(file):
     with open(file, 'r', encoding='utf8') as f:
-        lines = f.readlines()
-
-        unique_lines = set(lines)
+        unique_lines = set(f)
 
     with open(file, 'w', encoding='utf8') as f:
         f.writelines(unique_lines)
@@ -22,7 +21,7 @@ async def main():
     except IndexError:
         print('List of available providers:')
 
-        for file_name in os.listdir(os.path.dirname(__file__)):
+        for file_name in await aiofiles.os.listdir(os.path.dirname(__file__)):
             if file_name.endswith('.py') and not file_name.startswith('_'):
                 print(file_name.split('.')[0])
 
diff --git a/api/proxies.py b/api/proxies.py
index 6b9568a..f234363 100644
--- a/api/proxies.py
+++ b/api/proxies.py
@@ -96,7 +96,7 @@ proxies_in_files = []
 
 for proxy_type in ['http', 'socks4', 'socks5']:
     try:
-        with open(f'secret/proxies/{proxy_type}.txt') as f:
+        with open(os.path.join('secret', 'proxies', f'{proxy_type}.txt')) as f:
             for line in f:
                 clean_line = line.split('#', 1)[0].strip()
                 if clean_line:
diff --git a/requirements.txt b/requirements.txt
index 4b45671..d7e366c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+aiofiles==23.2.1
 aiohttp==3.8.5
 aiohttp_socks==0.8.0
 dhooks==1.1.4
diff --git a/rewards/main.py b/rewards/main.py
index cac3e61..710bdb0 100644
--- a/rewards/main.py
+++ b/rewards/main.py
@@ -51,7 +51,7 @@ async def update_roles():
 def launch():
     asyncio.run(main())
 
-    with open('rewards/last_update.txt', 'w', encoding='utf8') as f:
+    with open(os.path.join('rewards', 'last_update.txt'), 'w', encoding='utf8') as f:
         f.write(str(time.time()))
 
 if __name__ == '__main__':
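
Note on PATCH 3/3: json.dump() requires a file object with a synchronous write() method, which an aiofiles handle does not provide, so the patch streams the output of json.JSONEncoder().iterencode() through await f.write() instead of building one large string with json.dumps(). A standalone sketch of that pattern follows; the file name and payload are made up for illustration.

import asyncio
import json

import aiofiles

async def dump_json(path: str, data: object) -> None:
    """Write data as JSON chunk by chunk instead of as one large string."""
    encoder = json.JSONEncoder()
    async with aiofiles.open(path, 'w') as f:
        for chunk in encoder.iterencode(data):  # yields small str fragments
            await f.write(chunk)

if __name__ == '__main__':
    # Hypothetical file name and payload, for illustration only.
    asyncio.run(dump_json('example.json', {'documents': list(range(5))}))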
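
The same patch also swaps blocking os.makedirs(), os.listdir() and open() calls for their aiofiles / aiofiles.os counterparts, which hand the file work to a thread pool instead of blocking the event loop. A companion sketch under the same caveat: the directory name below is a placeholder, not the repository's api/secret path.

import asyncio
import os
from typing import List

import aiofiles
import aiofiles.os

async def read_key_files(directory: str) -> List[str]:
    """Collect non-empty lines from every .txt file in the given directory."""
    await aiofiles.os.makedirs(directory, exist_ok=True)   # async counterpart of os.makedirs
    keys: List[str] = []
    for filename in await aiofiles.os.listdir(directory):  # async counterpart of os.listdir
        if not filename.endswith('.txt'):
            continue
        async with aiofiles.open(os.path.join(directory, filename)) as f:
            async for line in f:                            # read lines without blocking the loop
                if line.strip():
                    keys.append(line.strip())
    return keys

if __name__ == '__main__':
    print(asyncio.run(read_key_files('keys')))  # 'keys' is a placeholder directory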