mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 21:13:57 +01:00
Merge pull request #1 from RayBytes/main
Multiple Keys integration + Rotational Keys
This commit is contained in:
commit
6aa22e8c55
75
api/keys.py
Normal file
75
api/keys.py
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from pymongo import MongoClient
|
||||||
|
|
||||||
|
with open('./config.json', 'r') as file:
|
||||||
|
config = json.load(file)
|
||||||
|
|
||||||
|
class Keys:
|
||||||
|
# --- START OF CONFIG ---
|
||||||
|
MONGO_URI = os.getenv('MONGO_URI') or config.get('MONGO_URI')
|
||||||
|
# --- END OF CONFIG ---
|
||||||
|
|
||||||
|
locked_keys = set()
|
||||||
|
cache = {}
|
||||||
|
|
||||||
|
# Initialize MongoDB
|
||||||
|
client = MongoClient(MONGO_URI)
|
||||||
|
db = client.get_database('keys_db')
|
||||||
|
collection = db['keys']
|
||||||
|
|
||||||
|
def __init__(self, key: str, model: str):
|
||||||
|
self.key = key
|
||||||
|
self.model = model
|
||||||
|
if not Keys.cache:
|
||||||
|
self._load_keys()
|
||||||
|
|
||||||
|
def _load_keys(self) -> None:
|
||||||
|
cursor = Keys.collection.find({}, {'_id': 0, 'key_value': 1, 'model': 1})
|
||||||
|
for doc in cursor:
|
||||||
|
key_value = doc['key_value']
|
||||||
|
model = doc['model']
|
||||||
|
Keys.cache.setdefault(model, set()).add(key_value)
|
||||||
|
|
||||||
|
def lock(self) -> None:
|
||||||
|
self.locked_keys.add(self.key)
|
||||||
|
|
||||||
|
def unlock(self) -> None:
|
||||||
|
self.locked_keys.remove(self.key)
|
||||||
|
|
||||||
|
def is_locked(self) -> bool:
|
||||||
|
return self.key in self.locked_keys
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get(model: str) -> str:
|
||||||
|
key_candidates = list(Keys.cache.get(model, set()))
|
||||||
|
random.shuffle(key_candidates)
|
||||||
|
|
||||||
|
for key_candidate in key_candidates:
|
||||||
|
key = Keys(key_candidate, model)
|
||||||
|
|
||||||
|
if not key.is_locked():
|
||||||
|
key.lock()
|
||||||
|
return key.key
|
||||||
|
|
||||||
|
print(f"[WARN] No unlocked keys found for model '{model}' in get keys request!")
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
Keys.collection.delete_one({'key_value': self.key, 'model': self.model})
|
||||||
|
# Update cache
|
||||||
|
try:
|
||||||
|
Keys.cache[self.model].remove(self.key)
|
||||||
|
except KeyError:
|
||||||
|
print(f"[WARN] Tried to remove a key from cache which was not present: {self.key}")
|
||||||
|
|
||||||
|
def save(self) -> None:
|
||||||
|
Keys.collection.insert_one({'key_value': self.key, 'model': self.model})
|
||||||
|
# Update cache
|
||||||
|
Keys.cache.setdefault(self.model, set()).add(self.key)
|
||||||
|
|
||||||
|
# Usage example:
|
||||||
|
# os.environ['MONGO_URI'] = "mongodb://localhost:27017"
|
||||||
|
# key_instance = Keys("example_key", "example_model")
|
||||||
|
# key_instance.save()
|
||||||
|
# key_value = Keys.get("example_model")
|
19
api/main.py
19
api/main.py
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import httpx
|
import httpx
|
||||||
import fastapi
|
import fastapi
|
||||||
|
from keys import Keys
|
||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import StreamingResponse
|
from starlette.responses import StreamingResponse
|
||||||
|
@ -31,24 +32,34 @@ async def startup_event():
|
||||||
security.enable_proxy()
|
security.enable_proxy()
|
||||||
security.ip_protection_check()
|
security.ip_protection_check()
|
||||||
|
|
||||||
|
# Setup key cache
|
||||||
|
Keys()
|
||||||
|
|
||||||
@app.get('/')
|
@app.get('/')
|
||||||
async def root():
|
async def root():
|
||||||
"""Returns the root endpoint."""
|
"""Returns the root endpoint."""
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'status': 'ok',
|
'status': 'ok',
|
||||||
'discord': 'https://discord.gg/mX9BYdFeQF',
|
'discord': 'https://discord.gg/85gdcd57Xr',
|
||||||
'github': 'https://github.com/Luna-OSS'
|
'github': 'https://github.com/Luna-OSS'
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _reverse_proxy(request: Request):
|
async def _reverse_proxy(request: Request):
|
||||||
target_url = f'https://api.openai.com/v1/{request.url.path}'
|
target_url = f'https://api.openai.com/v1/{request.url.path}'
|
||||||
|
key = Keys.get(request.body()['model'])
|
||||||
|
if not key:
|
||||||
|
return fastapi.responses.JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={
|
||||||
|
'error': 'No API Key for model given, please try again with a valid model.'
|
||||||
|
}
|
||||||
|
)
|
||||||
request_to_api = target_api_client.build_request(
|
request_to_api = target_api_client.build_request(
|
||||||
method=request.method,
|
method=request.method,
|
||||||
url=target_url,
|
url=target_url,
|
||||||
headers={
|
headers={
|
||||||
'Authorization': 'Bearer ' + os.getenv('OPENAI_KEY'),
|
'Authorization': 'Bearer ' + key,
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
},
|
},
|
||||||
content=await request.body(),
|
content=await request.body(),
|
||||||
|
@ -57,7 +68,7 @@ async def _reverse_proxy(request: Request):
|
||||||
api_response = await target_api_client.send(request_to_api, stream=True)
|
api_response = await target_api_client.send(request_to_api, stream=True)
|
||||||
|
|
||||||
print(f'[{request.method}] {request.url.path} {api_response.status_code}')
|
print(f'[{request.method}] {request.url.path} {api_response.status_code}')
|
||||||
|
Keys(key).unlock()
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
api_response.aiter_raw(),
|
api_response.aiter_raw(),
|
||||||
status_code=api_response.status_code,
|
status_code=api_response.status_code,
|
||||||
|
|
3
config.json
Normal file
3
config.json
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"MONGO_URI": ""
|
||||||
|
}
|
Loading…
Reference in a new issue