Skip to content

Commit

Permalink
Merge pull request #34 from edgenai/fix/issue33
Browse files Browse the repository at this point in the history
Fix/issue33
  • Loading branch information
toschoo authored Feb 6, 2024
2 parents 5793f16 + 126b4d4 commit 0a9c28f
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 69 deletions.
66 changes: 22 additions & 44 deletions crates/edgen_core/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,30 @@ pub static SETTINGS: Lazy<RwLock<StaticSettings>> = Lazy::new(Default::default);
pub static PROJECT_DIRS: Lazy<ProjectDirs> =
Lazy::new(|| ProjectDirs::from("com", "EdgenAI", "Edgen").unwrap());
pub static CONFIG_FILE: Lazy<PathBuf> = Lazy::new(|| build_config_file_path());
pub static CHAT_COMPLETIONS_MODEL_DIR: Lazy<PathBuf> =
Lazy::new(|| build_chat_completions_model_dir());
pub static AUDIO_TRANSCRIPTIONS_MODEL_DIR: Lazy<PathBuf> =
Lazy::new(|| build_audio_transcriptions_model_dir());

/// Create project dirs if they don't exist
pub fn create_project_dirs() -> Result<(), std::io::Error> {
pub async fn create_project_dirs() -> Result<(), std::io::Error> {
let config_dir = PROJECT_DIRS.config_dir();
let chat_completions_dir = get_chat_completions_model_dir();
let audio_transcriptions_dir = get_audio_transcriptions_model_dir();

let chat_completions_str = SETTINGS
.read()
.await
.read()
.await
.chat_completions_models_dir
.to_string();

let chat_completions_dir = PathBuf::from(chat_completions_str);

let audio_transcriptions_str = SETTINGS
.read()
.await
.read()
.await
.audio_transcriptions_models_dir
.to_string();

let audio_transcriptions_dir = PathBuf::from(audio_transcriptions_str);

if !config_dir.is_dir() {
std::fs::create_dir_all(&config_dir)?;
Expand All @@ -68,7 +82,7 @@ pub fn create_default_config_file() -> Result<(), std::io::Error> {
return Ok(()); // everything is fine
}

create_project_dirs()?;
std::fs::create_dir_all(&PROJECT_DIRS.config_dir())?;

tokio::runtime::Runtime::new().unwrap().block_on(async {
StaticSettings { inner: None }.init().await.unwrap();
Expand All @@ -82,47 +96,11 @@ pub fn get_config_file_path() -> PathBuf {
CONFIG_FILE.to_path_buf()
}

/// Returns an owned copy of the path to the chat completions model directory.
pub fn get_chat_completions_model_dir() -> PathBuf {
    CHAT_COMPLETIONS_MODEL_DIR.as_path().to_owned()
}

/// Returns an owned copy of the path to the audio transcriptions model directory.
pub fn get_audio_transcriptions_model_dir() -> PathBuf {
    AUDIO_TRANSCRIPTIONS_MODEL_DIR.as_path().to_owned()
}

/// Returns the chat completions model directory as an owned `String`.
///
/// # Panics
///
/// Panics if the directory path is not valid Unicode.
pub fn get_chat_completions_model_dir_as_string() -> String {
    let dir = get_chat_completions_model_dir().into_os_string();
    dir.into_string().unwrap()
}

/// Returns the audio transcriptions model directory as an owned `String`.
///
/// # Panics
///
/// Panics if the directory path is not valid Unicode.
pub fn get_audio_transcriptions_model_dir_as_string() -> String {
    let dir = get_audio_transcriptions_model_dir().into_os_string();
    dir.into_string().unwrap()
}

fn build_config_file_path() -> PathBuf {
let config_dir = PROJECT_DIRS.config_dir();
config_dir.join(Path::new("edgen.conf.yaml"))
}

fn build_chat_completions_model_dir() -> PathBuf {
let data_dir = PROJECT_DIRS.data_dir();
data_dir.join(Path::new("models/chat/completions"))
}

fn build_audio_transcriptions_model_dir() -> PathBuf {
let data_dir = PROJECT_DIRS.data_dir();
data_dir.join(Path::new("models/audio/transcriptions"))
}

#[derive(Error, Debug, Serialize)]
pub enum SettingsError {
#[error("failed to read the settings file: {0}")]
Expand Down
6 changes: 2 additions & 4 deletions crates/edgen_server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ pub type EdgenResult = Result<(), String>;

/// Main entry point for the server process
pub fn start(command: &cli::TopLevel) -> EdgenResult {
// if the project dirs do not exist, try to create them.
// if that fails, exit.
settings::create_project_dirs().unwrap();

match &command.subcommand {
None => serve(&cli::Serve::default())?,
Some(cli::Command::Serve(serve_args)) => serve(serve_args)?,
Expand Down Expand Up @@ -160,6 +156,8 @@ async fn start_server(args: &cli::Serve) -> EdgenResult {
.await
.expect("Failed to initialise settings");

settings::create_project_dirs().await.unwrap();

while run_server(args).await {
info!("Settings have been updated, resetting environment")
}
Expand Down
33 changes: 25 additions & 8 deletions crates/edgen_server/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ pub enum ModelError {
/// Kind of model being handled.
///
/// The visible download code matches on this value to decide whether
/// progress and download status are reported through the chat completions
/// or the audio transcriptions status functions.
pub enum ModelKind {
/// Model served by the chat completions endpoint.
LLM,
/// Model served by the audio transcriptions endpoint.
Whisper,
// NOTE(review): no construction or match arm for this variant is visible
// in this view — confirm whether it is still needed.
Unknown,
}

enum ModelQuantization {
Expand Down Expand Up @@ -92,23 +91,41 @@ impl Model {
.get(&self.name)
.is_none();
let size = self.get_size(&api).await;
let progress_handle =
status::observe_chat_completions_progress(&self.dir, size, download).await;
let progress_handle = match self.kind {
ModelKind::LLM => {
status::observe_chat_completions_progress(&self.dir, size, download).await
}
ModelKind::Whisper => {
status::observe_audio_transcriptions_progress(&self.dir, size, download).await
}
};

let name = self.name.clone();
let kind = self.kind.clone();
let download_handle = tokio::spawn(async move {
if download {
status::set_chat_completions_download(true).await;
}
match kind {
ModelKind::LLM => status::set_chat_completions_download(true).await,
ModelKind::Whisper => status::set_audio_transcriptions_download(true).await,
}
};

let path = api
.get(&name)
.map_err(move |e| ModelError::API(e.to_string()));

if download {
status::set_chat_completions_progress(100).await;
status::set_chat_completions_download(false).await;
}
match kind {
ModelKind::LLM => {
status::set_chat_completions_progress(100).await;
status::set_chat_completions_download(false).await;
}
ModelKind::Whisper => {
status::set_audio_transcriptions_progress(100).await;
status::set_audio_transcriptions_download(false).await;
}
}
};

return path;
});
Expand Down
32 changes: 19 additions & 13 deletions crates/edgen_server/src/openai_shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::path::PathBuf;

use axum::http::StatusCode;
use axum::response::sse::Event;
Expand All @@ -37,7 +38,6 @@ use utoipa::ToSchema;
use uuid::Uuid;

use edgen_core::settings::SETTINGS;
use edgen_core::settings::{get_audio_transcriptions_model_dir, get_chat_completions_model_dir};
use edgen_core::whisper::WhisperEndpointError;

use crate::model::{Model, ModelError, ModelKind};
Expand Down Expand Up @@ -592,6 +592,14 @@ pub async fn chat_completions(
.chat_completions_model_repo
.trim()
.to_string();
let dir = SETTINGS
.read()
.await
.read()
.await
.chat_completions_models_dir
.trim()
.to_string();

// invalid
if model_name.is_empty() {
Expand All @@ -600,12 +608,7 @@ pub async fn chat_completions(
});
}

let mut model = Model::new(
ModelKind::LLM,
&model_name,
&repo,
&get_chat_completions_model_dir(),
);
let mut model = Model::new(ModelKind::LLM, &model_name, &repo, &PathBuf::from(&dir));

model
.preload()
Expand Down Expand Up @@ -753,6 +756,14 @@ pub async fn create_transcription(
.audio_transcriptions_model_repo
.trim()
.to_string();
let dir = SETTINGS
.read()
.await
.read()
.await
.audio_transcriptions_models_dir
.trim()
.to_string();

// invalid
if model_name.is_empty() {
Expand All @@ -762,12 +773,7 @@ pub async fn create_transcription(
});
}

let mut model = Model::new(
ModelKind::Whisper,
&model_name,
&repo,
&get_audio_transcriptions_model_dir(),
);
let mut model = Model::new(ModelKind::Whisper, &model_name, &repo, &PathBuf::from(&dir));

model.preload().await?;

Expand Down

0 comments on commit 0a9c28f

Please sign in to comment.