diff --git a/Cargo.toml b/Cargo.toml index 950547f..621cc0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nova-python" -version = "0.1.2" +version = "0.1.3" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -12,3 +12,4 @@ crate-type = ["cdylib"] pyo3 = "0.19.0" reqwest = "0.11.18" tokio = { version = "1.29.1", features = ["rt-multi-thread", "time"] } +serde_json = "1.0.104" \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 42182db..e44ce55 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,7 +54,7 @@ impl NovaClient { } // Used to make a request to the Nova API - fn make_request(&self, endpoint: Endpoints, model: Models, data: Vec>, seconds_until_timeout: Option) -> PyResult { + fn make_request(&self, endpoint: Endpoints, model: Models, data: Vec>, seconds_until_timeout: Option) -> PyResult { if !model_is_compatible(&endpoint, &model) { return Err(NovaClient::get_endpoint_not_compatible_error()); } @@ -64,13 +64,13 @@ impl NovaClient { let rt = tokio::runtime::Runtime::new().unwrap(); let seconds_until_timeout = match seconds_until_timeout { - Some(seconds_until_timeout) => seconds_until_timeout.parse::().unwrap(), + Some(seconds_until_timeout) => seconds_until_timeout, None => 30 }; - let response: Result = rt.block_on(async { + let unmatched_response: Result = rt.block_on(async { let client = reqwest::Client::builder() - .timeout(time::Duration::from_secs(seconds_until_timeout)) + .timeout(time::Duration::from_secs(seconds_until_timeout as u64)) .user_agent("Mozilla/5.0") .build() .unwrap(); @@ -86,9 +86,25 @@ impl NovaClient { Ok(text) }); - match response { - Ok(response) => Ok(response), - Err(response) => Err(pyo3::exceptions::PyRuntimeError::new_err(response.to_string())) + let reponse = match unmatched_response { + Ok(unmatched_response) => Ok(unmatched_response), + Err(unmatched_response) => Err(pyo3::exceptions::PyRuntimeError::new_err(unmatched_response.to_string())) + }.unwrap(); + + if endpoint == Endpoints::ChatCompletion { + let final_reponse = Python::with_gil(|py| { + let reponse = ChatResponse::new(reponse); + reponse.into_py(py) + }); + + return Ok(final_reponse); + + } else { + let final_reponse = Python::with_gil(|py| { + reponse.into_py(py) + }); + + return Ok(final_reponse); } } @@ -181,9 +197,36 @@ impl NovaClient { fn get_invalid_model_error() -> PyErr { pyo3::exceptions::PyValueError::new_err("Invalid model") } +} - fn get_request_failed_error() -> PyErr { - pyo3::exceptions::PyRuntimeError::new_err("Request failed for unknown reasons.") +#[pyclass(module = "nova_python", frozen)] +struct ChatResponse { + json: String, +} + +#[pymethods] +impl ChatResponse { + #[new] + fn new(json: String) -> Self { + ChatResponse { + json + } + } + + fn get_message_content(&self) -> PyResult { + let json = &self.json; + let json: serde_json::Value = serde_json::from_str(json).unwrap(); + + let content = json["choices"][0]["message"]["content"].as_str().unwrap(); + Ok(content.trim().to_string()) + } + + fn __str__(&self) -> PyResult { + Ok(self.json.clone()) + } + + fn __repr__(&self) -> PyResult { + Ok(format!("ChatResponse(json={})", self.json)) } }