diff --git a/tools/viking/src/functions.rs b/tools/viking/src/functions.rs index c044b92c..e8799d93 100644 --- a/tools/viking/src/functions.rs +++ b/tools/viking/src/functions.rs @@ -1,11 +1,13 @@ use crate::repo; use anyhow::{bail, ensure, Context, Result}; +use rayon::prelude::*; use rustc_hash::FxHashMap; use std::{ collections::HashSet, path::{Path, PathBuf}, }; +#[derive(Clone, Debug, PartialEq, Eq)] pub enum Status { Matching, NonMatchingMinor, @@ -28,6 +30,7 @@ impl Status { } } +#[derive(Clone, Debug)] pub struct Info { pub addr: u64, pub size: u32, @@ -41,6 +44,7 @@ impl Info { } } +pub const CSV_HEADER: &[&str] = &["Address", "Quality", "Size", "Name"]; pub const ADDRESS_BASE: u64 = 0x71_0000_0000; fn parse_base_16(value: &str) -> Result { @@ -104,11 +108,7 @@ pub fn get_functions_for_path(csv_path: &Path) -> Result> { if reader.read_record(&mut record)? { // Verify that the CSV has the correct format. ensure!(record.len() == 4, "invalid record; expected 4 fields"); - ensure!( - &record[0] == "Address" - && &record[1] == "Quality" - && &record[2] == "Size" - && &record[3] == "Name", + ensure!(record == *CSV_HEADER, "wrong CSV format; this program only works with the new function list format (added in commit 1d4c815fbae3)" ); line_number += 1; @@ -148,11 +148,38 @@ pub fn get_functions_for_path(csv_path: &Path) -> Result> { Ok(result) } +pub fn write_functions_to_path(csv_path: &Path, functions: &[Info]) -> Result<()> { + let mut writer = csv::Writer::from_path(csv_path)?; + writer.write_record(CSV_HEADER)?; + + for function in functions { + let addr = format!("0x{:016x}", function.addr | ADDRESS_BASE); + let status = match function.status { + Status::Matching => "O", + Status::NonMatchingMinor => "m", + Status::NonMatchingMajor => "M", + Status::NotDecompiled => "U", + Status::Wip => "W", + Status::Library => "L", + } + .to_string(); + let size = format!("{:06}", function.size); + let name = function.name.clone(); + writer.write_record(&[addr, status, size, name])?; + } + + Ok(()) +} + /// Returns a Vec of all known functions in the executable. pub fn get_functions() -> Result> { get_functions_for_path(get_functions_csv_path()?.as_path()) } +pub fn write_functions(functions: &[Info]) -> Result<()> { + write_functions_to_path(get_functions_csv_path()?.as_path(), functions) +} + pub fn make_known_function_map(functions: &[Info]) -> FxHashMap { let mut known_functions = FxHashMap::with_capacity_and_hasher(functions.len(), Default::default()); @@ -177,3 +204,18 @@ pub fn demangle_str(name: &str) -> Result { let options = cpp_demangle::DemangleOptions::new(); Ok(symbol.demangle(&options)?) } + +pub fn find_function_fuzzy<'a>(functions: &'a [Info], name: &str) -> Option<&'a Info> { + functions + .par_iter() + .find_first(|function| function.name == name) + .or_else(|| { + // Comparing the demangled names is more expensive than a simple string comparison, + // so only do this as a last resort. + functions.par_iter().find_first(|function| { + demangle_str(&function.name) + .unwrap_or_else(|_| "".to_string()) + .contains(name) + }) + }) +} diff --git a/tools/viking/src/repo.rs b/tools/viking/src/repo.rs index b08528c6..2530d1d5 100644 --- a/tools/viking/src/repo.rs +++ b/tools/viking/src/repo.rs @@ -18,3 +18,7 @@ pub fn get_repo_root() -> Result { }; } } + +pub fn get_tools_path() -> Result { + Ok(get_repo_root()?.join("tools")) +} diff --git a/tools/viking/src/tools/check.rs b/tools/viking/src/tools/check.rs index 51b2d264..6cab54aa 100644 --- a/tools/viking/src/tools/check.rs +++ b/tools/viking/src/tools/check.rs @@ -1,9 +1,11 @@ use anyhow::bail; +use anyhow::ensure; use anyhow::Context; use anyhow::Result; use capstone as cs; use capstone::arch::BuildsCapstone; use colored::*; +use itertools::Itertools; use rayon::prelude::*; use std::cell::RefCell; use std::sync::atomic::AtomicBool; @@ -11,6 +13,7 @@ use viking::checks::FunctionChecker; use viking::elf; use viking::functions; use viking::functions::Status; +use viking::repo; use viking::ui; use mimalloc::MiMalloc; @@ -156,7 +159,108 @@ fn check_all( } } +fn get_function_to_check_from_args(args: &[String]) -> Result { + let mut maybe_fn_to_check: Vec = args + .iter() + .filter(|s| !s.starts_with("-")) + .map(|s| s.clone()) + .collect(); + + ensure!( + maybe_fn_to_check.len() == 1, + "expected only one function name (one argument that isn't prefixed with '-')" + ); + + Ok(maybe_fn_to_check.remove(0)) +} + +fn check_single( + functions: &[functions::Info], + checker: &FunctionChecker, + orig_elf: &elf::OwnedElf, + decomp_elf: &elf::OwnedElf, + decomp_symtab: &elf::SymbolTableByName, + args: &Vec, +) -> Result<()> { + let fn_to_check = get_function_to_check_from_args(&args)?; + let function = functions::find_function_fuzzy(&functions, &fn_to_check) + .with_context(|| format!("unknown function: {}", ui::format_symbol_name(&fn_to_check)))?; + let name = function.name.as_str(); + + eprintln!("{}", ui::format_symbol_name(name).bold()); + + if matches!(function.status, Status::Library) { + bail!("L functions should not be decompiled"); + } + + let decomp_fn = + elf::get_function_by_name(&decomp_elf, &decomp_symtab, &name).with_context(|| { + format!( + "failed to get decomp function: {}", + ui::format_symbol_name(name) + ) + })?; + + let orig_fn = elf::get_function(&orig_elf, function.addr, function.size as u64)?; + + let maybe_mismatch = checker + .check(&mut make_cs()?, &orig_fn, &decomp_fn) + .with_context(|| format!("checking {}", name))?; + + let mut should_show_diff = args + .iter() + .find(|s| s.as_str() == "--always-diff") + .is_some(); + + if let Some(mismatch) = &maybe_mismatch { + eprintln!("{}\n{}", "mismatch".red().bold(), &mismatch); + should_show_diff = true; + } else { + eprintln!("{}", "OK".green().bold()); + } + + if should_show_diff { + let diff_args = args + .iter() + .filter(|s| s.as_str() != &fn_to_check && s.as_str() != "--always-diff"); + + std::process::Command::new(repo::get_tools_path()?.join("asm-differ").join("diff.py")) + .arg("-I") + .arg("-e") + .arg(name) + .arg(format!("0x{:016x}", function.addr)) + .arg(format!("0x{:016x}", function.addr + function.size as u64)) + .args(diff_args) + .status()?; + } + + let new_status = match maybe_mismatch { + None => Status::Matching, + Some(_) => Status::Wip, + }; + + // Update the function status if needed. + if function.status != new_status { + ui::print_note(&format!( + "changing status from {:?} to {:?}", + function.status, new_status + )); + + let mut new_functions = functions.iter().cloned().collect_vec(); + new_functions + .iter_mut() + .find(|info| info.addr == function.addr) + .unwrap() + .status = new_status; + functions::write_functions(&new_functions)?; + } + + Ok(()) +} + fn main() -> Result<()> { + let args: Vec = std::env::args().skip(1).collect(); + let orig_elf = elf::load_orig_elf().with_context(|| "failed to load original ELF")?; let decomp_elf = elf::load_decomp_elf().with_context(|| "failed to load decomp ELF")?; @@ -192,7 +296,20 @@ fn main() -> Result<()> { ) .with_context(|| "failed to construct FunctionChecker")?; - check_all(&functions, &checker, &orig_elf, &decomp_elf, &decomp_symtab)?; + if args.len() >= 1 { + // Single function mode. + check_single( + &functions, + &checker, + &orig_elf, + &decomp_elf, + &decomp_symtab, + &args, + )?; + } else { + // Normal check mode. + check_all(&functions, &checker, &orig_elf, &decomp_elf, &decomp_symtab)?; + } Ok(()) }