Merge pull request #1 from RayBytes/main

Multiple Keys integration + Rotational Keys
This commit is contained in:
onlix 2023-06-25 12:02:57 +02:00 committed by GitHub
commit 6aa22e8c55
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 94 additions and 6 deletions

75
api/keys.py Normal file
View file

@ -0,0 +1,75 @@
import json
import os
import random
from pymongo import MongoClient
with open('./config.json', 'r') as file:
config = json.load(file)
class Keys:
# --- START OF CONFIG ---
MONGO_URI = os.getenv('MONGO_URI') or config.get('MONGO_URI')
# --- END OF CONFIG ---
locked_keys = set()
cache = {}
# Initialize MongoDB
client = MongoClient(MONGO_URI)
db = client.get_database('keys_db')
collection = db['keys']
def __init__(self, key: str, model: str):
self.key = key
self.model = model
if not Keys.cache:
self._load_keys()
def _load_keys(self) -> None:
cursor = Keys.collection.find({}, {'_id': 0, 'key_value': 1, 'model': 1})
for doc in cursor:
key_value = doc['key_value']
model = doc['model']
Keys.cache.setdefault(model, set()).add(key_value)
def lock(self) -> None:
self.locked_keys.add(self.key)
def unlock(self) -> None:
self.locked_keys.remove(self.key)
def is_locked(self) -> bool:
return self.key in self.locked_keys
@staticmethod
def get(model: str) -> str:
key_candidates = list(Keys.cache.get(model, set()))
random.shuffle(key_candidates)
for key_candidate in key_candidates:
key = Keys(key_candidate, model)
if not key.is_locked():
key.lock()
return key.key
print(f"[WARN] No unlocked keys found for model '{model}' in get keys request!")
def delete(self) -> None:
Keys.collection.delete_one({'key_value': self.key, 'model': self.model})
# Update cache
try:
Keys.cache[self.model].remove(self.key)
except KeyError:
print(f"[WARN] Tried to remove a key from cache which was not present: {self.key}")
def save(self) -> None:
Keys.collection.insert_one({'key_value': self.key, 'model': self.model})
# Update cache
Keys.cache.setdefault(self.model, set()).add(self.key)
# Usage example:
# os.environ['MONGO_URI'] = "mongodb://localhost:27017"
# key_instance = Keys("example_key", "example_model")
# key_instance.save()
# key_value = Keys.get("example_model")

View file

@ -1,6 +1,7 @@
import os import os
import httpx import httpx
import fastapi import fastapi
from keys import Keys
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import StreamingResponse from starlette.responses import StreamingResponse
@ -31,33 +32,43 @@ async def startup_event():
security.enable_proxy() security.enable_proxy()
security.ip_protection_check() security.ip_protection_check()
# Setup key cache
Keys()
@app.get('/') @app.get('/')
async def root(): async def root():
"""Returns the root endpoint.""" """Returns the root endpoint."""
return { return {
'status': 'ok', 'status': 'ok',
'discord': 'https://discord.gg/mX9BYdFeQF', 'discord': 'https://discord.gg/85gdcd57Xr',
'github': 'https://github.com/Luna-OSS' 'github': 'https://github.com/Luna-OSS'
} }
async def _reverse_proxy(request: Request): async def _reverse_proxy(request: Request):
target_url = f'https://api.openai.com/v1/{request.url.path}' target_url = f'https://api.openai.com/v1/{request.url.path}'
key = Keys.get(request.body()['model'])
if not key:
return fastapi.responses.JSONResponse(
status_code=400,
content={
'error': 'No API Key for model given, please try again with a valid model.'
}
)
request_to_api = target_api_client.build_request( request_to_api = target_api_client.build_request(
method=request.method, method=request.method,
url=target_url, url=target_url,
headers={ headers={
'Authorization': 'Bearer ' + os.getenv('OPENAI_KEY'), 'Authorization': 'Bearer ' + key,
'Content-Type': 'application/json' 'Content-Type': 'application/json'
}, },
content=await request.body(), content=await request.body(),
) )
api_response = await target_api_client.send(request_to_api, stream=True) api_response = await target_api_client.send(request_to_api, stream=True)
print(f'[{request.method}] {request.url.path} {api_response.status_code}') print(f'[{request.method}] {request.url.path} {api_response.status_code}')
Keys(key).unlock()
return StreamingResponse( return StreamingResponse(
api_response.aiter_raw(), api_response.aiter_raw(),
status_code=api_response.status_code, status_code=api_response.status_code,

3
config.json Normal file
View file

@ -0,0 +1,3 @@
{
"MONGO_URI": ""
}

View file

@ -18,4 +18,3 @@ setuptools.setup(
'Operating System :: OS Independent', 'Operating System :: OS Independent',
] ]
) )