2023-08-04 03:30:56 +02:00
|
|
|
import random
|
|
|
|
import asyncio
|
2023-09-23 21:41:48 +02:00
|
|
|
from db.key_validation import cached_key_is_rated
|
2023-08-04 03:30:56 +02:00
|
|
|
|
2023-08-06 00:43:36 +02:00
|
|
|
import providers
|
2023-08-04 03:30:56 +02:00
|
|
|
|
2023-08-06 21:42:07 +02:00
|
|
|
async def _get_module_name(module) -> str:
|
2023-08-04 17:29:49 +02:00
|
|
|
name = module.__name__
|
|
|
|
if '.' in name:
|
|
|
|
return name.split('.')[-1]
|
|
|
|
return name
|
|
|
|
|
2023-08-04 03:30:56 +02:00
|
|
|
async def balance_chat_request(payload: dict, max_key_attempts: int = 10) -> dict:
    """
    ### Load balance the chat completion request between chat providers.

    Providers in `providers.MODULES` are filtered by streaming support (only when
    `payload['stream']` is truthy) and by whether `payload['model']` is in their
    `MODELS`. One of the remaining providers is picked at random and its
    `chat_completion(**payload)` result (the target) is returned, with the provider's
    short module name stored under `target['module']`.

    While the target's `provider_auth` key is reported as rate-limited by
    `cached_key_is_rated`, `chat_completion` is called again on the same provider.
    At most `max_key_attempts` calls are made in total, so a provider whose keys are
    all rate-limited cannot hang the request; the last target is returned in that case.

    Raises:
        ValueError: no provider serves the requested model (message ends in
            MODEL_UNAVAILABLE).
    """
    wants_stream = payload.get('stream', False)

    # Same check order as before: streaming capability first, then the model list.
    providers_available = [
        provider_module
        for provider_module in providers.MODULES
        if (not wants_stream or provider_module.STREAMING)
        and payload['model'] in provider_module.MODELS
    ]

    if not providers_available:
        raise ValueError(f'The model "{payload["model"]}" is not available. MODEL_UNAVAILABLE')

    provider = random.choice(providers_available)
    target = await provider.chat_completion(**payload)

    # Re-request while the chosen key is rate-limited, but bounded so this can never
    # spin forever. The first call above counts as one attempt.
    for _ in range(max_key_attempts - 1):
        if not await cached_key_is_rated(target.get('provider_auth')):
            break
        target = await provider.chat_completion(**payload)

    target['module'] = await _get_module_name(provider)
    return target
|
2023-08-04 03:30:56 +02:00
|
|
|
|
|
|
|
async def balance_organic_request(request: dict) -> dict:
    """
    ### Load balance non-chat completion request

    Balances between other "organic" providers which respond in the desired format already.
    Organic providers are used for non-chat completions, such as moderation and other paths.

    If `request` has no (or empty) 'headers', a JSON `Content-Type` header dict is written
    into it; note this mutates the caller's dict. Requests whose 'path' contains
    '/moderations' are only sent to providers with `MODERATIONS` set. The chosen
    provider's `organify(request)` result is returned with its short module name stored
    under `target['module']`.

    Raises:
        ValueError: no organic provider can serve the request.
    """
    if not request.get('headers'):
        request['headers'] = {
            'Content-Type': 'application/json'
        }

    # The path does not change per provider, so test it once.
    is_moderation = '/moderations' in request['path']

    providers_available = [
        provider_module
        for provider_module in providers.MODULES
        if provider_module.ORGANIC
        and (not is_moderation or provider_module.MODERATIONS)
    ]

    # random.choice([]) would raise an opaque IndexError; fail descriptively instead,
    # mirroring balance_chat_request's ValueError for an unavailable model.
    if not providers_available:
        raise ValueError(f'No organic provider is available for path "{request["path"]}".')

    provider = random.choice(providers_available)
    target = await provider.organify(request)

    target['module'] = await _get_module_name(provider)
    return target
|
2023-08-04 03:30:56 +02:00
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: balance a streaming gpt-3.5-turbo chat request and print
    # the 'url' field of the target that comes back.
    demo_payload = {'model': 'gpt-3.5-turbo', 'stream': True}
    balanced_target = asyncio.run(balance_chat_request(payload=demo_payload))
    print(balanced_target['url'])
|