Skip to content
This repository has been archived by the owner on Oct 8, 2024. It is now read-only.

Commit

Permalink
Merge pull request #94 from PThorpe92/cmds
Browse files Browse the repository at this point in the history
fix broken auth/crashing
  • Loading branch information
PThorpe92 authored Nov 9, 2023
2 parents e653d10 + 6c9e1b2 commit 637c71c
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 133 deletions.
30 changes: 15 additions & 15 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: Rust

on:
push:
branches: [ "main" ]
branches: ["main"]
pull_request:
branches: [ "main" ]
branches: ["main"]

env:
CARGO_TERM_COLOR: always
Expand All @@ -18,17 +18,17 @@ jobs:
os: [ubuntu-latest, windows-latest, macos-latest]

steps:
- uses: actions/checkout@v3
- name: Set up Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
override: true
- name: Install Deps On Linux
if: runner.os == 'Linux'
run: |
- uses: actions/checkout@v3
- name: Set up Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
override: true
- name: Install Deps On Linux
if: runner.os == 'Linux'
run: |
sudo apt-get install libxcb-render0-dev libxcb-shape0-dev libxcb-xfixes0-dev
- name: Build
run: cargo build --verbose
- name: Run tests
run: cargo test --verbose
- name: Build
run: cargo build --verbose
- name: Run tests
run: cargo test --verbose
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ lazy_static = "1.4.0"
rusqlite = { version = "0.29.0", features = ["bundled"] }
serde_json = { version = "1.0.108", features = ["std"] }
serde = { version = "1.0.190", features = ["derive"] }
curl = "0.4.44"
curl = { features = ["ntlm", "http2"], version = "0.4.44" }
mockito = "1.2.0"
regex = "1.10.2"
dirs = "5.0.1"
Expand Down
19 changes: 18 additions & 1 deletion src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::database::db::{SavedCommand, SavedKey, DB};
use crate::display::menuopts::OPTION_PADDING_MID;
use crate::display::AppOptions;
use crate::request::command::{CmdOpts, CMD};
use crate::request::curl::Curl;
use crate::request::curl::{Curl, AuthKind};
use crate::screens::screen::Screen;
use crate::Config;
use std::{error, mem};
Expand Down Expand Up @@ -189,6 +189,7 @@ impl<'a> App<'a> {
.as_mut()
.unwrap()
.execute(Some(&mut self.db))

}

pub fn get_saved_keys(&self) -> Result<Vec<SavedKey>, rusqlite::Error> {
Expand Down Expand Up @@ -386,6 +387,8 @@ impl<'a> App<'a> {
AppOptions::TcpKeepAlive => self.command.as_mut().unwrap().set_tcp_keepalive(true),

AppOptions::SaveToken => self.command.as_mut().unwrap().save_token(true),
// Auth will be toggled for all types except for Basic, Bearer and digest
AppOptions::Auth(ref kind) => self.command.as_mut().unwrap().set_auth(kind.clone()),
_ => {}
}
self.opts.push(opt);
Expand All @@ -402,6 +405,20 @@ impl<'a> App<'a> {
if self.should_add_option(&opt) {
self.opts.push(opt.clone());
match opt {
// other options will be set at the input menu
// TODO: Consolidate this garbage spaghetti nonsense
AppOptions::Auth(authkind) => match authkind {
AuthKind::Spnego => {
self.command.as_mut().unwrap().set_auth(authkind);
}
AuthKind::Ntlm => {
self.command.as_mut().unwrap().set_auth(authkind);
}
AuthKind::AwsSigv4 => {
self.command.as_mut().unwrap().set_auth(authkind);
}
_ => {}
}
AppOptions::UnixSocket(socket) => self.command.as_mut().unwrap().set_unix_socket(&socket),

AppOptions::Headers(value) => self.command.as_mut().unwrap().add_headers(value),
Expand Down
10 changes: 4 additions & 6 deletions src/display/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::display::menuopts::{
DISPLAY_OPT_MAX_REC, DISPLAY_OPT_MAX_REDIRECTS, DISPLAY_OPT_REFERRER,
use crate::{
display::menuopts::{DISPLAY_OPT_MAX_REC, DISPLAY_OPT_MAX_REDIRECTS, DISPLAY_OPT_REFERRER},
request::curl::AuthKind,
};

use self::menuopts::{
Expand Down Expand Up @@ -34,7 +35,7 @@ pub enum AppOptions {
SaveCommand,
Response(String),
RecDownload(usize),
Auth(String),
Auth(AuthKind),
SaveToken,
UnixSocket(String),
FollowRedirects,
Expand Down Expand Up @@ -73,9 +74,6 @@ impl AppOptions {
AppOptions::RecDownload(ref mut level) => {
*level = val.parse::<usize>().unwrap();
}
AppOptions::Auth(ref mut auth) => {
*auth = val;
}
AppOptions::UnixSocket(ref mut socket) => {
*socket = val;
}
Expand Down
145 changes: 82 additions & 63 deletions src/request/curl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ impl Display for AuthKind {
AuthKind::None => write!(f, "None"),
AuthKind::Ntlm => write!(f, "NTLM"),
AuthKind::Basic(login) => write!(f, "Basic: {}", login),
AuthKind::Bearer(token) => write!(f, "Bearer: {}", token),
AuthKind::Bearer(token) => write!(f, "Authorization: Bearer {}", token),
AuthKind::Digest(login) => write!(f, "Digest Auth: {}", login),
AuthKind::AwsSigv4 => write!(f, "AWS SignatureV4"),
AuthKind::Spnego => write!(f, "SPNEGO Auth"),
Expand Down Expand Up @@ -338,74 +338,93 @@ impl<'a> CmdOpts for Curl<'a> {

fn execute(&mut self, mut db: Option<&mut Box<DB>>) -> Result<(), String> {
let mut list = List::new();
curl::init();

// Setup auth if we have it, will return whether we appended to the list
let mut has_headers = self.handle_auth_exec(&mut list);
if self.headers.is_some() {
has_headers = true;
self.headers
.as_ref()
.unwrap()

// Handle headers
if let Some(ref headers) = self.headers {
headers
.iter()
.for_each(|h| list.append(h.as_str()).unwrap());
has_headers = true;
}

// Save command to DB
if self.will_save_command() {
let _ = db.as_mut().unwrap().add_command(
&self.get_command_string(),
serde_json::to_string(&self).unwrap_or(String::from("Error serializing command")),
);
if let Some(ref mut db) = db {
let command_string = &self.get_command_string();
let command_json = serde_json::to_string(&self)
.map_err(|e| format!("Error serializing command: {}", e))?;
if db.add_command(command_string, command_json).is_err() {
println!("Error saving command to DB");
}
}
}
// Save token to DB
if self.will_save_token() {
let _ = db
.unwrap()
.add_key(&self.auth.get_token().unwrap_or_default());
if let Some(ref mut db) = db {
if db
.add_key(&self.auth.get_token().unwrap_or_default())
.is_err()
{
println!("Error saving token to DB");
}
}
}
// We have to append the list of headers all at once
// but if we never appended to the list, we skip this
// Append headers if needed
if has_headers {
self.curl.http_headers(list).unwrap();
self.curl
.http_headers(list)
.map_err(|e| format!("Error setting headers: {:?}", e))?;
}

// If we are uploading a file...
// Upload file if specified
if let Some(ref upload_file) = self.upload_file {
let file = std::fs::File::open(upload_file).unwrap();
let mut buff: Vec<u8> = Vec::new();
let mut reader = std::io::BufReader::new(file);
let _ = reader.read_to_end(&mut buff);
// set connect only + establish connection to the URL
self.curl.connect_only(true).unwrap();
if self.curl.perform().is_ok() {
// Upload the file contents
if self.curl.send(buff.as_slice()).is_ok() {
Ok(())
} else {
Err(String::from("Error with upload"))
}
} else {
Err(String::from("Error making connection"))
}
} else if self.curl.perform().is_ok() {
let contents = self.curl.get_ref();
let res = String::from_utf8_lossy(&contents.0);
if let Ok(json) =
serde_json::from_str::<serde_json::Value>(&String::from_utf8_lossy(&contents.0))
{
self.resp = Some(serde_json::to_string_pretty(&json).unwrap());
Ok(())
} else {
self.resp = Some(res.to_string());
Ok(())
if let Ok(file) = std::fs::File::open(upload_file) {
let mut buff: Vec<u8> = Vec::new();
let mut reader = std::io::BufReader::new(file);
reader
.read_to_end(&mut buff)
.map_err(|e| format!("Error reading file: {}", e))?;

// set connect only + establish connection to the URL
self.curl
.connect_only(true)
.map_err(|e| format!("Error connecting: {:?}", e))?;

// Handle upload errors
self.curl
.perform()
.map_err(|err| format!("Error making connection: {:?}", err))?;
self.curl
.send(buff.as_slice())
.map_err(|e| format!("Error with upload: {}", e))?;
}
}

// Perform the main request
self.curl
.perform()
.map_err(|err| format!("Error: {:?}", err))?;
let contents = self.curl.get_ref();
let res = String::from_utf8_lossy(&contents.0);
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&res) {
self.resp = Some(serde_json::to_string_pretty(&json).unwrap());
} else {
return Err(String::from("Error executing command"));
self.resp = Some(res.to_string());
}
Ok(())
}
}

impl<'a> CurlOpts for Curl<'a> {
fn set_auth(&mut self, auth: AuthKind) {
match auth {
AuthKind::Basic(info) => self.set_basic_auth(info),
AuthKind::Basic(ref info) => self.set_basic_auth(info),
AuthKind::Ntlm => self.set_ntlm_auth(),
AuthKind::Bearer(token) => self.set_bearer_auth(token),
AuthKind::Bearer(ref token) => self.set_bearer_auth(token),
AuthKind::AwsSigv4 => self.set_aws_sigv4_auth(),
AuthKind::Digest(login) => self.set_digest_auth(&login),
AuthKind::Spnego => self.set_spnego_auth(),
Expand All @@ -428,6 +447,7 @@ impl<'a> CurlOpts for Curl<'a> {
_ => {}
}
}

fn set_cert_info(&mut self, opt: bool) {
let flag = CurlFlag::CertInfo(CurlFlagType::CertInfo.get_value(), None);
self.toggle_flag(&flag);
Expand Down Expand Up @@ -622,7 +642,7 @@ impl<'a> Curl<'a> {
> 0
}

// This is a hack because when we deseialize from the DB, we get a curl struct with no curl::Easy
// This is a hack because when we deseialize json from the DB, we get a curl struct with no curl::Easy
// field, so we have to manually add, then set the options one at a time from the opts vector.
// ANY time we get a command from the database to run, we have to call this method first.
pub fn easy_from_opts(&mut self) {
Expand Down Expand Up @@ -725,12 +745,12 @@ impl<'a> Curl<'a> {
let _ = self.curl.http_auth(&Auth::new());
}

pub fn set_basic_auth(&mut self, login: String) {
pub fn set_basic_auth(&mut self, login: &str) {
self.add_flag(CurlFlag::Basic(
CurlFlagType::Basic.get_value(),
Some(login.to_string()),
));
self.auth = AuthKind::Basic(login);
self.auth = AuthKind::Basic(String::from(login));
}

pub fn toggle_flag(&mut self, flag: &CurlFlag<'a>) {
Expand Down Expand Up @@ -805,12 +825,12 @@ impl<'a> Curl<'a> {
self.auth = AuthKind::Ntlm;
}

pub fn set_bearer_auth(&mut self, token: String) {
pub fn set_bearer_auth(&mut self, token: &str) {
self.add_flag(CurlFlag::Bearer(
CurlFlagType::Bearer.get_value(),
Some(format!("Authorization: Bearer {}", token)),
Some(format!("Authorization: Bearer {token}")),
));
self.auth = AuthKind::Bearer(token);
self.auth = AuthKind::Bearer(String::from(token));
}

pub fn show_headers(&mut self, file: &str) {
Expand Down Expand Up @@ -843,9 +863,7 @@ impl<'a> Curl<'a> {
self.cmd = cmd.join(" ").trim().to_string();
}

pub fn handle_auth_exec(&mut self, list: &mut List) -> bool {
// we need to know if we have appended to this list
let mut list_edited = false;
fn handle_auth_exec(&mut self, list: &mut List) -> bool {
match &self.auth {
AuthKind::None => {}
AuthKind::Basic(login) => {
Expand All @@ -855,12 +873,13 @@ impl<'a> Curl<'a> {
self.curl
.password(login.split(':').last().unwrap())
.unwrap();
println!("login: {}", login);
println!("login: {login}");
let _ = self.curl.http_auth(Auth::new().basic(true));
}
AuthKind::Bearer(token) => {
list_edited = true;
let _ = list.append(&format!("Authorization: Bearer {}", token.clone()));
AuthKind::Bearer(ref token) => {
list.append(&format!("Authorization: Bearer {token}"))
.unwrap();
return true;
}
AuthKind::Digest(login) => {
self.curl
Expand All @@ -880,8 +899,8 @@ impl<'a> Curl<'a> {
AuthKind::AwsSigv4 => {
let _ = self.curl.http_auth(Auth::new().aws_sigv4(true));
}
};
list_edited
}
false
}

pub fn url_encode(&mut self, data: &str) {
Expand Down Expand Up @@ -1237,7 +1256,7 @@ mod tests {
fn test_set_basic_auth() {
let mut curl = Curl::new();
let usr_pwd = "username:password";
curl.set_basic_auth(usr_pwd.to_string());
curl.set_basic_auth(usr_pwd);
assert_eq!(curl.opts.len(), 1);
assert!(curl.opts.contains(&CurlFlag::Basic(
CurlFlagType::Basic.get_value(),
Expand Down
Loading

0 comments on commit 637c71c

Please sign in to comment.