nova-api/api/load_balancing.py

76 lines
2 KiB
Python
Raw Normal View History

2023-08-04 03:30:56 +02:00
import random
import asyncio
import providers
2023-08-04 03:30:56 +02:00
# Provider modules that the balancers below choose from at random.
# Commented-out entries are excluded from selection.
provider_modules = [
    # providers.twa,
    # providers.quantum,
    providers.churchless,
    providers.closed,
    providers.closed4,
]
2023-08-06 21:42:07 +02:00
async def _get_module_name(module) -> str:
    """Return the last dotted component of ``module.__name__``.

    A name that contains no dot is returned unchanged.
    """
    # rpartition gives ('', '', name) when there is no dot, so the final
    # element is the short name in both cases.
    return module.__name__.rpartition('.')[-1]
2023-08-04 03:30:56 +02:00
async def balance_chat_request(payload: dict) -> dict:
    """Load balance the chat completion request between chat providers.

    Args:
        payload: Chat completion parameters. ``payload['model']`` is
            required. ``payload['stream']`` is optional and treated as falsy
            when absent. The whole dict is forwarded to the chosen provider's
            ``chat_completion`` as keyword arguments.

    Returns:
        The value returned by the provider's ``chat_completion``, with a
        ``'module'`` key added that names the provider module that was picked.

    Raises:
        NotImplementedError: If no provider module lists the requested model
            (and supports streaming, when streaming was requested).
    """
    # A request that omits 'stream' is a non-streaming request, not an error.
    wants_stream = payload.get('stream', False)

    providers_available = [
        provider_module
        for provider_module in provider_modules
        if (provider_module.STREAMING or not wants_stream)
        and payload['model'] in provider_module.MODELS
    ]

    if not providers_available:
        raise NotImplementedError('This model does not exist.')

    provider = random.choice(providers_available)
    target = provider.chat_completion(**payload)

    # Record which provider module was selected for this request.
    target['module'] = await _get_module_name(provider)

    return target
2023-08-04 03:30:56 +02:00
async def balance_organic_request(request: dict) -> dict:
    """Load balance a non-chat-completion request between "organic" providers.

    "Organic" providers are those which respond in the desired format already.

    Args:
        request: Request description. ``request['path']`` is inspected to
            detect moderation requests. ``request['headers']`` is set in
            place to a JSON ``Content-Type`` header when missing or empty.

    Returns:
        The value returned by the chosen provider's ``organify``, with a
        ``'module'`` key added that names the provider module that was picked.

    Raises:
        NotImplementedError: If no provider module can serve the request.
    """
    if not request.get('headers'):
        request['headers'] = {
            'Content-Type': 'application/json'
        }

    # Keep organic providers only; for moderation paths the provider must
    # additionally advertise MODERATIONS support.
    providers_available = [
        provider_module
        for provider_module in provider_modules
        if provider_module.ORGANIC
        and ('/moderations' not in request['path'] or provider_module.MODERATIONS)
    ]

    if not providers_available:
        # random.choice() would raise a bare IndexError on an empty list;
        # fail explicitly, the same way balance_chat_request does.
        raise NotImplementedError('No provider is available for this request.')

    provider = random.choice(providers_available)
    target = provider.organify(request)

    # Record which provider module was selected for this request.
    target['module'] = await _get_module_name(provider)

    return target
2023-08-04 03:30:56 +02:00
if __name__ == '__main__':
    # Manual smoke test: route one streaming chat request and print the
    # 'url' entry of the result.
    demo_payload = {'model': 'gpt-3.5-turbo', 'stream': True}
    routed = asyncio.run(balance_chat_request(payload=demo_payload))
    print(routed['url'])