Compare commits

..

7 commits

Author SHA1 Message Date
monosans 258c5c6cfc
Merge ade7244cea into 003a7d3d71 2023-10-09 06:02:16 +00:00
monosans ade7244cea
Refactor file operations 2023-10-09 09:02:09 +03:00
nsde 003a7d3d71
Merge pull request #18 from monosans/patch-2
Fix dangling asyncio tasks
2023-10-09 00:06:45 +02:00
nsde ad9f442fa1
Merge pull request #19 from monosans/patch-3
Add missing await
2023-10-09 00:06:14 +02:00
nsde 23a904f3ce Added buffering, fixing a common chunk yielding issue 2023-10-08 23:56:32 +02:00
monosans de2710539f
Add missing await 2023-10-08 23:05:11 +03:00
monosans 007e078fb6
Fix dangling asyncio tasks 2023-10-08 23:05:03 +03:00
3 changed files with 36 additions and 14 deletions

View file

@ -38,10 +38,10 @@ async def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') ->
tokens_per_name = -1 # if there's a name, the role is omitted
elif 'gpt-3.5-turbo' in model:
return count_for_messages(messages, model='gpt-3.5-turbo-0613')
return await count_for_messages(messages, model='gpt-3.5-turbo-0613')
elif 'gpt-4' in model:
return count_for_messages(messages, model='gpt-4-0613')
return await count_for_messages(messages, model='gpt-4-0613')
else:
raise NotImplementedError(f"""count_for_messages() is not implemented for model {model}.

View file

@ -14,7 +14,7 @@ MODELS = [
]
# MODELS = [f'{model}-azure' for model in MODELS]
AZURE_API = '2023-07-01-preview'
AZURE_API = '2023-08-01-preview'
async def chat_completion(**payload):
key = await utils.random_secret_for('azure-nva1')

View file

@ -7,6 +7,7 @@ import aiohttp
import asyncio
import starlette
from typing import Any, Coroutine, Set
from rich import print
from dotenv import load_dotenv
@ -23,6 +24,19 @@ CRITICAL_API_ERRORS = ['invalid_api_key', 'account_deactivated']
keymanager = providerkeys.manager
background_tasks: Set[asyncio.Task[Any]] = set()
def create_background_task(coro: Coroutine[Any, Any, Any]) -> None:
"""asyncio.create_task, which prevents the task from being garbage collected.
https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
"""
task = asyncio.create_task(coro)
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)
async def respond(
path: str='/v1/chat/completions',
user: dict=None,
@ -49,7 +63,7 @@ async def respond(
'Content-Type': 'application/json'
}
for i in range(1):
for i in range(5):
try:
if is_chat:
target_request = await load_balancing.balance_chat_request(payload)
@ -96,7 +110,7 @@ async def respond(
cookies=target_request.get('cookies'),
ssl=False,
timeout=aiohttp.ClientTimeout(
connect=1.0,
connect=0.75,
total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
)
) as response:
@ -148,24 +162,32 @@ async def respond(
continue
chunk_no = 0
async for chunk in response.content.iter_any():
buffer = ''
async for chunk in response.content.iter_chunked(1024):
chunk_no += 1
chunk = chunk.decode('utf8').strip()
chunk = chunk.decode('utf8')
if 'azure' in provider_name:
chunk = chunk.strip().replace('data: ', '')
chunk = chunk.replace('data: ', '')
if not chunk or chunk_no == 1:
continue
yield chunk + '\n\n'
subchunks = chunk.split('\n\n')
buffer += subchunks[0]
yield buffer + '\n\n'
buffer = subchunks[-1]
for subchunk in subchunks[1:-1]:
yield subchunk + '\n\n'
break
except Exception as exc:
print('[!] exception', exc)
# continue
raise exc
except aiohttp.client_exceptions.ServerTimeoutError:
continue
else:
yield await errors.yield_error(500, 'Sorry, our API seems to have issues connecting to our provider(s).', 'This most likely isn\'t your fault. Please try again later.')
@ -174,7 +196,7 @@ async def respond(
if (not is_stream) and server_json_response:
yield json.dumps(server_json_response)
asyncio.create_task(
create_background_task(
after_request.after_request(
incoming_request=incoming_request,
target_request=target_request,