Skip to content

Commit

Permalink
Ai branch (#95)
Browse files Browse the repository at this point in the history
* 👷 add initial AI sdk & tools profiles

* 🚧 profiles: adjust ai packages

* profiles/ai-sdk: Add pattern matching for most modern NVIDIA GPUs (#94)

---------

Co-authored-by: Vasiliy Stelmachenok <[email protected]>
  • Loading branch information
vnepogodin and ventureoo authored Apr 29, 2024
1 parent 2b3eedc commit a70eaec
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 5 deletions.
21 changes: 21 additions & 0 deletions profiles/pci/ai_sdk/profiles.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# VENDOR AMD=1002 INTEL=8086 NVIDIA=10de
# CLASSID 03=Display controller
# 00=VGA compatible controller 02=3D controller 80=Display controller

# NVIDIA cards
#CLASSIDS="0300 0302"
#VENDORIDS="10de"
#DEVICEIDS=">/var/lib/mhwd/ids/pci/nvidia.ids"

[nvidia-ai-sdk]
desc = 'NVIDIA AI SDK and related tools'
nonfree = true
ai_sdk = true
class_ids = "*"
#class_ids = "0300 0380 0302"
vendor_ids = "10de"
priority = 9
packages = 'cuda cudnn nccl python-pytorch-opt-cuda ollama-cuda tensorflow-opt-cuda python-tensorflow-opt-cuda'
device_name_pattern = '(AD|GV|TU|GA|GH|GM|GP)\w+'
#device_ids = '*'

4 changes: 4 additions & 0 deletions src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ pub struct Args {
#[arg(long = "is_nvidia_card")]
pub is_nvidia_card: bool,

/// Toggle AI SDK profiles
#[arg(long = "ai_sdk")]
pub is_ai_sdk: bool,

#[arg(long, default_value_t = String::from("/var/cache/pacman/pkg"))]
pub pmcachedir: String,
#[arg(long, default_value_t = String::from("/etc/pacman.conf"))]
Expand Down
6 changes: 6 additions & 0 deletions src/console_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ pub fn handle_arguments_listing(data: &Data, args: &crate::args::Args) {
}

pub fn list_profiles(profiles: &[Profile], header_msg: &str) {
//if profiles.iter().all(|x| x.is_ai_sdk) {
// return;
//}
print_status(header_msg);
println!();

Expand All @@ -93,6 +96,9 @@ pub fn list_profiles(profiles: &[Profile], header_msg: &str) {
.set_header(vec![&fl!("name-header"), &fl!("nonfree-header")]);

for profile in profiles.iter() {
//if profile.is_ai_sdk {
// continue;
//}
table.add_row(vec![&profile.name, &profile.is_nonfree.to_string()]);
}

Expand Down
19 changes: 16 additions & 3 deletions src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,19 @@ pub type ListOfDevicesT = Vec<Device>;
#[derive(Debug, Default)]
pub struct Data {
pub sync_package_manager_database: bool,
pub is_ai_sdk_target: bool,
pub pci_devices: ListOfDevicesT,
pub installed_pci_profiles: ListOfProfilesT,
pub all_pci_profiles: ListOfProfilesT,
pub invalid_profiles: Vec<String>,
}

impl Data {
pub fn new() -> Self {
pub fn new(is_ai_sdk: bool) -> Self {
let mut res = Self {
pci_devices: fill_devices().expect("Failed to init"),
sync_package_manager_database: true,
is_ai_sdk_target: is_ai_sdk,
..Default::default()
};

Expand All @@ -65,14 +67,14 @@ impl Data {
let conf_path = crate::consts::CHWD_PCI_DATABASE_DIR;
let configs = &mut self.installed_pci_profiles;

fill_profiles(configs, &mut self.invalid_profiles, conf_path);
fill_profiles(configs, &mut self.invalid_profiles, conf_path, self.is_ai_sdk_target);
}

fn fill_all_profiles(&mut self) {
let conf_path = crate::consts::CHWD_PCI_CONFIG_DIR;
let configs = &mut self.all_pci_profiles;

fill_profiles(configs, &mut self.invalid_profiles, conf_path);
fill_profiles(configs, &mut self.invalid_profiles, conf_path, self.is_ai_sdk_target);
}

fn update_profiles_data(&mut self) {
Expand All @@ -94,6 +96,7 @@ fn fill_profiles(
configs: &mut ListOfProfilesT,
invalid_profiles: &mut Vec<String>,
conf_path: &str,
is_ai_sdk: bool,
) {
for entry in fs::read_dir(conf_path).expect("Failed to read directory!") {
let config_file_path = format!(
Expand All @@ -109,6 +112,16 @@ fn fill_profiles(
if profile.packages.is_empty() {
continue;
}
// if we dont target ai sdk,
// skip profile marked as ai sdk.
if !is_ai_sdk && profile.is_ai_sdk {
continue;
}
// if we target ai sdk,
// skip profile which isn't marked as ai sdk.
if is_ai_sdk && !profile.is_ai_sdk {
continue;
}
configs.push(profile);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ fn main() -> anyhow::Result<()> {
}

// 2) Initialize
let mut data_obj = data::Data::new();
let mut data_obj = data::Data::new(argstruct.is_ai_sdk);

let missing_dirs = misc::check_environment();
if !missing_dirs.is_empty() {
Expand Down
2 changes: 1 addition & 1 deletion src/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub fn find_profile(profile_name: &str, profiles: &[Profile]) -> Option<Arc<Prof
}

pub fn check_nvidia_card() {
let data = data::Data::new();
let data = data::Data::new(false);
for pci_device in data.pci_devices.iter() {
if pci_device.available_profiles.is_empty() {
continue;
Expand Down
4 changes: 4 additions & 0 deletions src/profile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ pub struct HardwareID {
pub struct Profile {
pub is_nonfree: bool,

pub is_ai_sdk: bool,

pub prof_path: String,
pub name: String,
pub desc: String,
Expand Down Expand Up @@ -136,6 +138,7 @@ pub fn get_invalid_profiles(file_path: &str) -> Result<Vec<String>> {
fn parse_profile(node: &toml::Table, profile_name: &str) -> Result<Profile> {
let mut profile = Profile {
is_nonfree: node.get("nonfree").and_then(|x| x.as_bool()).unwrap_or(false),
is_ai_sdk: node.get("ai_sdk").and_then(|x| x.as_bool()).unwrap_or(false),
prof_path: "".to_owned(),
name: profile_name.to_owned(),
packages: node.get("packages").and_then(|x| x.as_str()).unwrap_or("").to_owned(),
Expand Down Expand Up @@ -229,6 +232,7 @@ fn merge_table_left(lhs: &mut toml::Table, rhs: &toml::Table) {
pub fn write_profile_to_file(file_path: &str, profile: &Profile) -> bool {
let mut table = toml::Table::new();
table.insert("nonfree".to_owned(), profile.is_nonfree.into());
table.insert("ai_sdk".to_owned(), profile.is_ai_sdk.into());
table.insert("desc".to_owned(), profile.desc.clone().into());
table.insert("packages".to_owned(), profile.packages.clone().into());
table.insert("priority".to_owned(), profile.priority.into());
Expand Down

0 comments on commit a70eaec

Please sign in to comment.