diff --git a/api/keys.py b/api/keys.py new file mode 100644 index 0000000..5eb978b --- /dev/null +++ b/api/keys.py @@ -0,0 +1,75 @@ +import json +import os +import random +from pymongo import MongoClient + +with open('./config.json', 'r') as file: + config = json.load(file) + +class Keys: + # --- START OF CONFIG --- + MONGO_URI = os.getenv('MONGO_URI') or config.get('MONGO_URI') + # --- END OF CONFIG --- + + locked_keys = set() + cache = {} + + # Initialize MongoDB + client = MongoClient(MONGO_URI) + db = client.get_database('keys_db') + collection = db['keys'] + + def __init__(self, key: str, model: str): + self.key = key + self.model = model + if not Keys.cache: + self._load_keys() + + def _load_keys(self) -> None: + cursor = Keys.collection.find({}, {'_id': 0, 'key_value': 1, 'model': 1}) + for doc in cursor: + key_value = doc['key_value'] + model = doc['model'] + Keys.cache.setdefault(model, set()).add(key_value) + + def lock(self) -> None: + self.locked_keys.add(self.key) + + def unlock(self) -> None: + self.locked_keys.remove(self.key) + + def is_locked(self) -> bool: + return self.key in self.locked_keys + + @staticmethod + def get(model: str) -> str: + key_candidates = list(Keys.cache.get(model, set())) + random.shuffle(key_candidates) + + for key_candidate in key_candidates: + key = Keys(key_candidate, model) + + if not key.is_locked(): + key.lock() + return key.key + + print(f"[WARN] No unlocked keys found for model '{model}' in get keys request!") + + def delete(self) -> None: + Keys.collection.delete_one({'key_value': self.key, 'model': self.model}) + # Update cache + try: + Keys.cache[self.model].remove(self.key) + except KeyError: + print(f"[WARN] Tried to remove a key from cache which was not present: {self.key}") + + def save(self) -> None: + Keys.collection.insert_one({'key_value': self.key, 'model': self.model}) + # Update cache + Keys.cache.setdefault(self.model, set()).add(self.key) + +# Usage example: +# os.environ['MONGO_URI'] = "mongodb://localhost:27017" +# key_instance = Keys("example_key", "example_model") +# key_instance.save() +# key_value = Keys.get("example_model") diff --git a/api/main.py b/api/main.py index c50a0a0..ed04c66 100644 --- a/api/main.py +++ b/api/main.py @@ -1,6 +1,7 @@ import os import httpx import fastapi +from keys import Keys from starlette.requests import Request from starlette.responses import StreamingResponse @@ -31,33 +32,43 @@ async def startup_event(): security.enable_proxy() security.ip_protection_check() + # Setup key cache + Keys() + @app.get('/') async def root(): """Returns the root endpoint.""" return { 'status': 'ok', - 'discord': 'https://discord.gg/mX9BYdFeQF', + 'discord': 'https://discord.gg/85gdcd57Xr', 'github': 'https://github.com/Luna-OSS' } async def _reverse_proxy(request: Request): target_url = f'https://api.openai.com/v1/{request.url.path}' - + key = Keys.get(request.body()['model']) + if not key: + return fastapi.responses.JSONResponse( + status_code=400, + content={ + 'error': 'No API Key for model given, please try again with a valid model.' + } + ) request_to_api = target_api_client.build_request( method=request.method, url=target_url, headers={ - 'Authorization': 'Bearer ' + os.getenv('OPENAI_KEY'), + 'Authorization': 'Bearer ' + key, 'Content-Type': 'application/json' }, content=await request.body(), ) api_response = await target_api_client.send(request_to_api, stream=True) - + print(f'[{request.method}] {request.url.path} {api_response.status_code}') - + Keys(key).unlock() return StreamingResponse( api_response.aiter_raw(), status_code=api_response.status_code, diff --git a/config.json b/config.json new file mode 100644 index 0000000..fb89592 --- /dev/null +++ b/config.json @@ -0,0 +1,3 @@ +{ + "MONGO_URI": "" +} diff --git a/setup.py b/setup.py index 72e5d4c..8c27dd1 100644 --- a/setup.py +++ b/setup.py @@ -18,4 +18,3 @@ setuptools.setup( 'Operating System :: OS Independent', ] ) -