Improved errors and checking, and fixed rate-limit retrying

This commit is contained in:
nsde 2023-10-05 14:17:53 +02:00
parent 01aa41b6b1
commit 6ef8441681
5 changed files with 37 additions and 36 deletions

View file

@ -12,7 +12,7 @@ class KeyManager:
self.conn = AsyncIOMotorClient(os.environ['MONGO_URI']) self.conn = AsyncIOMotorClient(os.environ['MONGO_URI'])
async def _get_collection(self, collection_name: str): async def _get_collection(self, collection_name: str):
return self.conn[os.getenv('MONGO_NAME', 'nova-test')][collection_name] return self.conn['nova-core'][collection_name]
async def add_key(self, provider: str, key: str, source: str='?'): async def add_key(self, provider: str, key: str, source: str='?'):
db = await self._get_collection('providerkeys') db = await self._get_collection('providerkeys')
@ -36,7 +36,7 @@ class KeyManager:
}) })
if key is None: if key is None:
return ValueError('No keys available for this provider!') return '--NO_KEY--'
return key['key'] return key['key']
@ -87,4 +87,4 @@ class KeyManager:
manager = KeyManager() manager = KeyManager()
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(manager.delete_empty_keys()) asyncio.run(manager.import_all())

View file

@ -8,8 +8,7 @@ async def error(code: int, message: str, tip: str) -> starlette.responses.Respon
'code': code, 'code': code,
'message': message, 'message': message,
'tip': tip, 'tip': tip,
'website': 'https://nova-oss.com', 'powered_by': 'nova-api'
'by': 'NovaOSS/Nova-API'
}} }}
return starlette.responses.Response(status_code=code, content=json.dumps(info)) return starlette.responses.Response(status_code=code, content=json.dumps(info))
@ -20,5 +19,6 @@ async def yield_error(code: int, message: str, tip: str) -> str:
return json.dumps({ return json.dumps({
'code': code, 'code': code,
'message': message, 'message': message,
'tip': tip 'tip': tip,
'powered_by': 'nova-api'
}) })

View file

@ -35,9 +35,9 @@ limiter = Limiter(
swallow_errors=True, swallow_errors=True,
key_func=get_remote_address, key_func=get_remote_address,
default_limits=[ default_limits=[
'1/second', '2/second',
'20/minute', '30/minute',
'300/hour' '400/hour'
]) ])
app.state.limiter = limiter app.state.limiter = limiter

View file

@ -38,7 +38,6 @@ async def respond(
is_chat = False is_chat = False
model = None model = None
is_stream = False
if 'chat/completions' in path: if 'chat/completions' in path:
is_chat = True is_chat = True
@ -73,6 +72,13 @@ async def respond(
provider_name = provider_auth.split('>')[0] provider_name = provider_auth.split('>')[0]
provider_key = provider_auth.split('>')[1] provider_key = provider_auth.split('>')[1]
if provider_key == '--NO_KEY--':
yield await errors.yield_error(500,
'Sorry, our API seems to have issues connecting to our provider(s).',
'This most likely isn\'t your fault. Please try again later.'
)
return
target_request['headers'].update(target_request.get('headers', {})) target_request['headers'].update(target_request.get('headers', {}))
if target_request['method'] == 'GET' and not payload: if target_request['method'] == 'GET' and not payload:
@ -91,12 +97,13 @@ async def respond(
timeout=aiohttp.ClientTimeout( timeout=aiohttp.ClientTimeout(
connect=1.0, connect=1.0,
total=float(os.getenv('TRANSFER_TIMEOUT', '500')) total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
), )
) as response: ) as response:
is_stream = response.content_type == 'text/event-stream' is_stream = response.content_type == 'text/event-stream'
if response.status == 429: if response.status == 429:
await keymanager.rate_limit_key(provider_name, provider_key) print('[!] rate limit')
# await keymanager.rate_limit_key(provider_name, provider_key)
continue continue
if response.content_type == 'application/json': if response.content_type == 'application/json':
@ -112,12 +119,14 @@ async def respond(
critical_error = True critical_error = True
if critical_error: if critical_error:
print('[!] critical error')
continue continue
if response.ok: if response.ok:
server_json_response = client_json_response server_json_response = client_json_response
else: else:
print('[!] non-ok response', client_json_response)
continue continue
if is_stream: if is_stream:
@ -136,10 +145,6 @@ async def respond(
except Exception as exc: except Exception as exc:
print('[!] exception', exc) print('[!] exception', exc)
if 'too many requests' in str(exc):
#!TODO
pass
continue continue
else: else:

View file

@ -26,6 +26,12 @@ MESSAGES = [
api_endpoint = os.getenv('CHECKS_ENDPOINT', 'http://localhost:2332/v1') api_endpoint = os.getenv('CHECKS_ENDPOINT', 'http://localhost:2332/v1')
async def _response_base_check(response: httpx.Response) -> None:
    """Raise ``ConnectionError`` if the API answered with an error status.

    Wraps ``response.raise_for_status()`` so that callers get a single
    exception type whose message includes the error body sent by the API.

    Args:
        response: The completed HTTP response to validate.

    Raises:
        ConnectionError: If the response has an error status code. The
            original ``httpx.HTTPStatusError`` is chained as the cause.
    """
    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        # Error bodies are not guaranteed to be JSON (e.g. an HTML page from a
        # reverse proxy or an empty body). Fall back to the raw text so that a
        # decode failure does not mask the real HTTP error.
        try:
            details = response.json()
        except ValueError:
            details = response.text
        raise ConnectionError(f'API returned an error: {details}') from exc
async def test_server(): async def test_server():
"""Tests if the API server is running.""" """Tests if the API server is running."""
@ -36,7 +42,7 @@ async def test_server():
url=f'{api_endpoint.replace("/v1", "")}', url=f'{api_endpoint.replace("/v1", "")}',
timeout=3 timeout=3
) )
response.raise_for_status() await _response_base_check(response)
assert response.json()['ping'] == 'pong', 'The API did not return a correct response.' assert response.json()['ping'] == 'pong', 'The API did not return a correct response.'
except httpx.ConnectError as exc: except httpx.ConnectError as exc:
@ -63,7 +69,7 @@ async def test_chat_non_stream_gpt4() -> float:
json=json_data, json=json_data,
timeout=10, timeout=10,
) )
response.raise_for_status() await _response_base_check(response)
assert '1337' in response.json()['choices'][0]['message']['content'], 'The API did not return a correct response.' assert '1337' in response.json()['choices'][0]['message']['content'], 'The API did not return a correct response.'
return time.perf_counter() - request_start return time.perf_counter() - request_start
@ -86,7 +92,7 @@ async def test_chat_stream_gpt3() -> float:
json=json_data, json=json_data,
timeout=10, timeout=10,
) )
response.raise_for_status() await _response_base_check(response)
chunks = [] chunks = []
resulting_text = '' resulting_text = ''
@ -128,7 +134,7 @@ async def test_image_generation() -> float:
json=json_data, json=json_data,
timeout=10, timeout=10,
) )
response.raise_for_status() await _response_base_check(response)
assert '://' in response.json()['data'][0]['url'] assert '://' in response.json()['data'][0]['url']
return time.perf_counter() - request_start return time.perf_counter() - request_start
@ -166,7 +172,7 @@ async def test_function_calling():
json=json_data, json=json_data,
timeout=15, timeout=15,
) )
response.raise_for_status() await _response_base_check(response)
res = response.json() res = response.json()
output = json.loads(res['choices'][0]['message']['function_call']['arguments']) output = json.loads(res['choices'][0]['message']['function_call']['arguments'])
@ -185,7 +191,7 @@ async def test_models():
headers=HEADERS, headers=HEADERS,
timeout=3 timeout=3
) )
response.raise_for_status() await _response_base_check(response)
res = response.json() res = response.json()
all_models = [model['id'] for model in res['data']] all_models = [model['id'] for model in res['data']]
@ -208,20 +214,10 @@ async def demo():
else: else:
raise ConnectionError('API Server is not running.') raise ConnectionError('API Server is not running.')
# print('[lightblue]Checking if function calling works...') for func in [test_chat_non_stream_gpt4, test_chat_stream_gpt3]:
# print(await test_function_calling()) print(f'[*] {func.__name__}')
result = await func()
print('Checking non-streamed chat completions...') print(f'[+] {func.__name__} - {result}')
print(await test_chat_non_stream_gpt4())
print('Checking streamed chat completions...')
print(await test_chat_stream_gpt3())
# print('[lightblue]Checking if image generation works...')
# print(await test_image_generation())
# print('Checking the models endpoint...')
# print(await test_models())
except Exception as exc: except Exception as exc:
print('[red]Error: ' + str(exc)) print('[red]Error: ' + str(exc))