mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 18:43:57 +01:00
Improved errors, checking and fixed ratelimit retrying
This commit is contained in:
parent
01aa41b6b1
commit
6ef8441681
|
@ -12,7 +12,7 @@ class KeyManager:
|
||||||
self.conn = AsyncIOMotorClient(os.environ['MONGO_URI'])
|
self.conn = AsyncIOMotorClient(os.environ['MONGO_URI'])
|
||||||
|
|
||||||
async def _get_collection(self, collection_name: str):
|
async def _get_collection(self, collection_name: str):
|
||||||
return self.conn[os.getenv('MONGO_NAME', 'nova-test')][collection_name]
|
return self.conn['nova-core'][collection_name]
|
||||||
|
|
||||||
async def add_key(self, provider: str, key: str, source: str='?'):
|
async def add_key(self, provider: str, key: str, source: str='?'):
|
||||||
db = await self._get_collection('providerkeys')
|
db = await self._get_collection('providerkeys')
|
||||||
|
@ -36,7 +36,7 @@ class KeyManager:
|
||||||
})
|
})
|
||||||
|
|
||||||
if key is None:
|
if key is None:
|
||||||
return ValueError('No keys available for this provider!')
|
return '--NO_KEY--'
|
||||||
|
|
||||||
return key['key']
|
return key['key']
|
||||||
|
|
||||||
|
@ -87,4 +87,4 @@ class KeyManager:
|
||||||
manager = KeyManager()
|
manager = KeyManager()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
asyncio.run(manager.delete_empty_keys())
|
asyncio.run(manager.import_all())
|
||||||
|
|
|
@ -8,8 +8,7 @@ async def error(code: int, message: str, tip: str) -> starlette.responses.Respon
|
||||||
'code': code,
|
'code': code,
|
||||||
'message': message,
|
'message': message,
|
||||||
'tip': tip,
|
'tip': tip,
|
||||||
'website': 'https://nova-oss.com',
|
'powered_by': 'nova-api'
|
||||||
'by': 'NovaOSS/Nova-API'
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
return starlette.responses.Response(status_code=code, content=json.dumps(info))
|
return starlette.responses.Response(status_code=code, content=json.dumps(info))
|
||||||
|
@ -20,5 +19,6 @@ async def yield_error(code: int, message: str, tip: str) -> str:
|
||||||
return json.dumps({
|
return json.dumps({
|
||||||
'code': code,
|
'code': code,
|
||||||
'message': message,
|
'message': message,
|
||||||
'tip': tip
|
'tip': tip,
|
||||||
|
'powered_by': 'nova-api'
|
||||||
})
|
})
|
||||||
|
|
|
@ -35,9 +35,9 @@ limiter = Limiter(
|
||||||
swallow_errors=True,
|
swallow_errors=True,
|
||||||
key_func=get_remote_address,
|
key_func=get_remote_address,
|
||||||
default_limits=[
|
default_limits=[
|
||||||
'1/second',
|
'2/second',
|
||||||
'20/minute',
|
'30/minute',
|
||||||
'300/hour'
|
'400/hour'
|
||||||
])
|
])
|
||||||
|
|
||||||
app.state.limiter = limiter
|
app.state.limiter = limiter
|
||||||
|
|
|
@ -38,7 +38,6 @@ async def respond(
|
||||||
is_chat = False
|
is_chat = False
|
||||||
|
|
||||||
model = None
|
model = None
|
||||||
is_stream = False
|
|
||||||
|
|
||||||
if 'chat/completions' in path:
|
if 'chat/completions' in path:
|
||||||
is_chat = True
|
is_chat = True
|
||||||
|
@ -73,6 +72,13 @@ async def respond(
|
||||||
provider_name = provider_auth.split('>')[0]
|
provider_name = provider_auth.split('>')[0]
|
||||||
provider_key = provider_auth.split('>')[1]
|
provider_key = provider_auth.split('>')[1]
|
||||||
|
|
||||||
|
if provider_key == '--NO_KEY--':
|
||||||
|
yield await errors.yield_error(500,
|
||||||
|
'Sorry, our API seems to have issues connecting to our provider(s).',
|
||||||
|
'This most likely isn\'t your fault. Please try again later.'
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
target_request['headers'].update(target_request.get('headers', {}))
|
target_request['headers'].update(target_request.get('headers', {}))
|
||||||
|
|
||||||
if target_request['method'] == 'GET' and not payload:
|
if target_request['method'] == 'GET' and not payload:
|
||||||
|
@ -91,12 +97,13 @@ async def respond(
|
||||||
timeout=aiohttp.ClientTimeout(
|
timeout=aiohttp.ClientTimeout(
|
||||||
connect=1.0,
|
connect=1.0,
|
||||||
total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
|
total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
|
||||||
),
|
)
|
||||||
) as response:
|
) as response:
|
||||||
is_stream = response.content_type == 'text/event-stream'
|
is_stream = response.content_type == 'text/event-stream'
|
||||||
|
|
||||||
if response.status == 429:
|
if response.status == 429:
|
||||||
await keymanager.rate_limit_key(provider_name, provider_key)
|
print('[!] rate limit')
|
||||||
|
# await keymanager.rate_limit_key(provider_name, provider_key)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if response.content_type == 'application/json':
|
if response.content_type == 'application/json':
|
||||||
|
@ -112,12 +119,14 @@ async def respond(
|
||||||
critical_error = True
|
critical_error = True
|
||||||
|
|
||||||
if critical_error:
|
if critical_error:
|
||||||
|
print('[!] critical error')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if response.ok:
|
if response.ok:
|
||||||
server_json_response = client_json_response
|
server_json_response = client_json_response
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
print('[!] non-ok response', client_json_response)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if is_stream:
|
if is_stream:
|
||||||
|
@ -136,10 +145,6 @@ async def respond(
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print('[!] exception', exc)
|
print('[!] exception', exc)
|
||||||
if 'too many requests' in str(exc):
|
|
||||||
#!TODO
|
|
||||||
pass
|
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -26,6 +26,12 @@ MESSAGES = [
|
||||||
|
|
||||||
api_endpoint = os.getenv('CHECKS_ENDPOINT', 'http://localhost:2332/v1')
|
api_endpoint = os.getenv('CHECKS_ENDPOINT', 'http://localhost:2332/v1')
|
||||||
|
|
||||||
|
async def _response_base_check(response: httpx.Response) -> None:
|
||||||
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
raise ConnectionError(f'API returned an error: {response.json()}') from exc
|
||||||
|
|
||||||
async def test_server():
|
async def test_server():
|
||||||
"""Tests if the API server is running."""
|
"""Tests if the API server is running."""
|
||||||
|
|
||||||
|
@ -36,7 +42,7 @@ async def test_server():
|
||||||
url=f'{api_endpoint.replace("/v1", "")}',
|
url=f'{api_endpoint.replace("/v1", "")}',
|
||||||
timeout=3
|
timeout=3
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
await _response_base_check(response)
|
||||||
|
|
||||||
assert response.json()['ping'] == 'pong', 'The API did not return a correct response.'
|
assert response.json()['ping'] == 'pong', 'The API did not return a correct response.'
|
||||||
except httpx.ConnectError as exc:
|
except httpx.ConnectError as exc:
|
||||||
|
@ -63,7 +69,7 @@ async def test_chat_non_stream_gpt4() -> float:
|
||||||
json=json_data,
|
json=json_data,
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
await _response_base_check(response)
|
||||||
|
|
||||||
assert '1337' in response.json()['choices'][0]['message']['content'], 'The API did not return a correct response.'
|
assert '1337' in response.json()['choices'][0]['message']['content'], 'The API did not return a correct response.'
|
||||||
return time.perf_counter() - request_start
|
return time.perf_counter() - request_start
|
||||||
|
@ -86,7 +92,7 @@ async def test_chat_stream_gpt3() -> float:
|
||||||
json=json_data,
|
json=json_data,
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
await _response_base_check(response)
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
resulting_text = ''
|
resulting_text = ''
|
||||||
|
@ -128,7 +134,7 @@ async def test_image_generation() -> float:
|
||||||
json=json_data,
|
json=json_data,
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
await _response_base_check(response)
|
||||||
|
|
||||||
assert '://' in response.json()['data'][0]['url']
|
assert '://' in response.json()['data'][0]['url']
|
||||||
return time.perf_counter() - request_start
|
return time.perf_counter() - request_start
|
||||||
|
@ -166,7 +172,7 @@ async def test_function_calling():
|
||||||
json=json_data,
|
json=json_data,
|
||||||
timeout=15,
|
timeout=15,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
await _response_base_check(response)
|
||||||
|
|
||||||
res = response.json()
|
res = response.json()
|
||||||
output = json.loads(res['choices'][0]['message']['function_call']['arguments'])
|
output = json.loads(res['choices'][0]['message']['function_call']['arguments'])
|
||||||
|
@ -185,7 +191,7 @@ async def test_models():
|
||||||
headers=HEADERS,
|
headers=HEADERS,
|
||||||
timeout=3
|
timeout=3
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
await _response_base_check(response)
|
||||||
res = response.json()
|
res = response.json()
|
||||||
|
|
||||||
all_models = [model['id'] for model in res['data']]
|
all_models = [model['id'] for model in res['data']]
|
||||||
|
@ -208,20 +214,10 @@ async def demo():
|
||||||
else:
|
else:
|
||||||
raise ConnectionError('API Server is not running.')
|
raise ConnectionError('API Server is not running.')
|
||||||
|
|
||||||
# print('[lightblue]Checking if function calling works...')
|
for func in [test_chat_non_stream_gpt4, test_chat_stream_gpt3]:
|
||||||
# print(await test_function_calling())
|
print(f'[*] {func.__name__}')
|
||||||
|
result = await func()
|
||||||
print('Checking non-streamed chat completions...')
|
print(f'[+] {func.__name__} - {result}')
|
||||||
print(await test_chat_non_stream_gpt4())
|
|
||||||
|
|
||||||
print('Checking streamed chat completions...')
|
|
||||||
print(await test_chat_stream_gpt3())
|
|
||||||
|
|
||||||
# print('[lightblue]Checking if image generation works...')
|
|
||||||
# print(await test_image_generation())
|
|
||||||
|
|
||||||
# print('Checking the models endpoint...')
|
|
||||||
# print(await test_models())
|
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print('[red]Error: ' + str(exc))
|
print('[red]Error: ' + str(exc))
|
||||||
|
|
Loading…
Reference in a new issue