test-harness.rs

Compares the new hand-written Windows command-line argument parser for Rust against the old `CommandLineToArgvW`-based one, to ensure they behave identically.

unlisted ⁨1⁩ ⁨file⁩ 2018-12-10 19:19:25 UTC

test-harness.rs

Raw
use std::collections::VecDeque;
use std::ffi::OsString;
use std::os::windows::ffi::OsStringExt;
use std::slice;
use std::iter;
use std::ptr;

/// The hand-written replacement for the Windows command-line splitting routine;
/// the algorithm it follows is described at
/// <https://docs.microsoft.com/en-us/previous-versions//17w5ykft(v=vs.85)>.
///
/// Windows ships an equivalent function in shell32.dll, but linking against that
/// DLL registers the process as a GUI application, which carries a lot of
/// overhead even when no window is ever drawn. See
/// <https://randomascii.wordpress.com/2018/12/03/a-not-called-function-can-cause-a-5x-slowdown/>.
///
/// This is a thin wrapper: it forwards to `parse_lp_cmd_line`, supplying
/// `current_exe` as the fallback used when the command line is null or empty.
unsafe fn new_parser(lp_cmd_line: *const u16) -> VecDeque<OsString> {
    parse_lp_cmd_line(lp_cmd_line, || current_exe())
}

/// Asks `GetModuleFileNameW` for a module path using a null module handle and
/// returns it as an `OsString`.
///
/// The path is read into a fixed 4096-element UTF-16 buffer. A return value of 0
/// from the call yields an empty string; otherwise the returned count is used as
/// the number of code units to keep.
unsafe fn current_exe() -> OsString {
    const CAPACITY: usize = 4096;
    let mut buffer = [0u16; CAPACITY];
    let written = GetModuleFileNameW(ptr::null_mut(), buffer.as_mut_ptr(), CAPACITY as u32);
    match written as usize {
        0 => OsString::new(),
        len => OsString::from_wide(&buffer[..len]),
    }
}
/// Splits a NUL-terminated UTF-16 command line into its arguments.
///
/// * `lp_cmd_line` - pointer to the first UTF-16 code unit of the command line.
/// * `exe_name` - invoked only when `lp_cmd_line` is null or points at an empty
///   string; its result is then the single argument returned.
///
/// The first argument (the executable name) follows its own rules:
/// * if it starts with `"`, it runs up to the next `"` (or the end of input),
///   with no escape processing at all;
/// * if the line starts with a code unit in `0..=SPACE`, it is the empty string;
/// * otherwise it runs up to the next code unit in `0..=SPACE`.
///
/// The remaining arguments are separated by unquoted spaces/tabs, with backslash
/// and quote handling as implemented by the state machine below.
///
/// # Safety
///
/// `lp_cmd_line` must be null or point to a readable sequence of `u16` values
/// that is terminated by a 0 code unit.
unsafe fn parse_lp_cmd_line<F: Fn() -> OsString>(lp_cmd_line: *const u16, exe_name: F)
                                                 -> VecDeque<OsString> {
    const BACKSLASH: u16 = '\\' as u16;
    const QUOTE: u16 = '"' as u16;
    const TAB: u16 = '\t' as u16;
    const SPACE: u16 = ' ' as u16;

    let mut ret_val = VecDeque::new();
    if lp_cmd_line.is_null() || *lp_cmd_line == 0 {
        ret_val.push_back(exe_name());
        return ret_val;
    }

    // Index of the code unit currently being examined.
    let mut i: isize = 0;

    // The executable name at the beginning is special.
    match *lp_cmd_line {
        // Quoted: the name ends at the next quote mark (or the terminator),
        // no matter what comes in between.
        QUOTE => {
            loop {
                i += 1;
                let unit = *lp_cmd_line.offset(i);
                if unit == 0 || unit == QUOTE {
                    break;
                }
            }
            // Skip the opening quote; `i` sits on the closing quote or the NUL.
            ret_val.push_back(OsString::from_wide(
                slice::from_raw_parts(lp_cmd_line.offset(1), i as usize - 1)
            ));
            if *lp_cmd_line.offset(i) == 0 {
                return ret_val;
            }
            i += 1;
        }
        // Implement quirk: when they say whitespace here,
        // they include the entire ASCII control plane:
        // "However, if lpCmdLine starts with any amount of whitespace, CommandLineToArgvW
        // will consider the first argument to be an empty string. Excess whitespace at the
        // end of lpCmdLine is ignored."
        0..=SPACE => {
            ret_val.push_back(OsString::new());
            i = 1;
        }
        // Unquoted: the name ends at the next whitespace/control code unit
        // (or the terminator).
        _ => {
            loop {
                i += 1;
                if let 0..=SPACE = *lp_cmd_line.offset(i) {
                    break;
                }
            }
            ret_val.push_back(OsString::from_wide(
                slice::from_raw_parts(lp_cmd_line, i as usize)
            ));
            if *lp_cmd_line.offset(i) == 0 {
                return ret_val;
            }
            // Step over the delimiter that ended the name.
            i += 1;
        }
    }

    // State for the remaining arguments.
    let mut cur: Vec<u16> = Vec::new(); // code units of the argument being built
    let mut in_quotes = false;          // currently inside a quoted region
    let mut was_in_quotes = false;      // the previous code unit closed a quoted region
    let mut backslash_count: usize = 0; // pending run of backslashes, not yet emitted

    loop {
        let c = *lp_cmd_line.offset(i);
        match c {
            // Backslashes are buffered: their meaning depends on what follows.
            BACKSLASH => {
                backslash_count += 1;
                was_in_quotes = false;
            }
            QUOTE => {
                // Before a quote, every pair of backslashes yields one literal
                // backslash; a leftover odd one escapes the quote itself.
                let escaped = backslash_count % 2 != 0;
                cur.extend(iter::repeat(BACKSLASH).take(backslash_count / 2));
                backslash_count = 0;
                if escaped || was_in_quotes {
                    // Literal quote: either backslash-escaped, or a quote that
                    // immediately follows a closing quote.
                    cur.push(QUOTE);
                    was_in_quotes = false;
                } else {
                    was_in_quotes = in_quotes;
                    in_quotes = !in_quotes;
                }
            }
            // Unquoted space/tab ends the current argument.
            SPACE | TAB if !in_quotes => {
                cur.extend(iter::repeat(BACKSLASH).take(backslash_count));
                // `was_in_quotes` keeps an empty quoted argument from being dropped.
                if !cur.is_empty() || was_in_quotes {
                    ret_val.push_back(OsString::from_wide(&cur[..]));
                    cur.clear();
                }
                backslash_count = 0;
                was_in_quotes = false;
            }
            // Terminator: flush whatever is pending and stop.
            0x00 => {
                cur.extend(iter::repeat(BACKSLASH).take(backslash_count));
                // include empty quoted strings at the end of the arguments list
                if !cur.is_empty() || was_in_quotes || in_quotes {
                    ret_val.push_back(OsString::from_wide(&cur[..]));
                }
                break;
            }
            // Any other code unit (including space/tab inside quotes) is literal;
            // backslashes not followed by a quote are literal too.
            _ => {
                cur.extend(iter::repeat(BACKSLASH).take(backslash_count));
                backslash_count = 0;
                was_in_quotes = false;
                cur.push(c);
            }
        }
        i += 1;
    }
    ret_val
}

/// Reference implementation: the parser built into Windows.
///
/// Hands the command line to `CommandLineToArgvW`, copies each returned
/// NUL-terminated UTF-16 string into an `OsString`, then releases the returned
/// array with `LocalFree`. A null result produces an empty list.
unsafe fn old_parser(lp_cmd_line: *const u16) -> VecDeque<OsString> {
    let mut argc: u32 = 0;
    let argv = CommandLineToArgvW(lp_cmd_line, &mut argc);
    if argv.is_null() {
        return VecDeque::new();
    }

    let args: VecDeque<OsString> = (0..argc as isize)
        .map(|index| {
            let start = *argv.offset(index);
            // Measure the string by scanning for its terminating 0.
            let mut units = 0usize;
            while *start.offset(units as isize) != 0 {
                units += 1;
            }
            OsString::from_wide(slice::from_raw_parts(start, units))
        })
        .collect();

    LocalFree(argv);
    args
}

// Windows API functions called by this harness. Parameter names follow the
// Win32 spelling rather than Rust's snake_case convention.
extern "system" {
    // The reference parser. `old_parser` treats the return value as an array of
    // `*pNumArgs` pointers to NUL-terminated UTF-16 strings, and as a failure
    // when it is null.
    fn CommandLineToArgvW(lpCmdLine: *const u16, pNumArgs: *mut u32) -> *mut *mut u16;
    // Called by `old_parser` to release the array returned by `CommandLineToArgvW`.
    // NOTE(review): the parameter is named `pNumArgs` although it receives that
    // array -- looks like a copy-paste from the declaration above; confirm.
    fn LocalFree(pNumArgs: *mut *mut u16);
    // Called by `current_exe` with a null `hModule` and a 4096-element buffer.
    // A return value of 0 is treated there as failure; any other value is used
    // as the number of UTF-16 code units written to `lpFilename`.
    pub fn GetModuleFileNameW(hModule: *mut u32,
                              lpFilename: *mut u16,
                              nSize: u32)
                              -> u32;
}

/// Brute-force comparison of `new_parser` against `old_parser`.
///
/// Every combination of four code units in `0..0xFF` is tried twice: once as the
/// whole command line, and once after an `a ` prefix standing in for the
/// executable name. The first mismatch prints the offending buffer and then
/// panics via `assert_eq!`. The outermost value is printed in hex as a progress
/// indicator.
fn main() {
    // For a one-off comparison, pass a hand-written NUL-terminated buffer to both
    // parsers and print the results, e.g. the UTF-16 code units of `a "a"" a`.

    // Runs both parsers on one NUL-terminated buffer and insists they agree.
    let check = |ucs_2: &[u16]| unsafe {
        let new_result = new_parser(ucs_2.as_ptr());
        let old_result = old_parser(ucs_2.as_ptr());
        if old_result != new_result {
            println!("ucs_2={:?}", ucs_2);
        }
        assert_eq!(old_result, new_result);
    };

    // Test with no executable at the beginning
    for a in 0..0xFFu16 {
        println!("{:x}", a);
        for b in 0..0xFFu16 {
            for c in 0..0xFFu16 {
                for d in 0..0xFFu16 {
                    check(&[a, b, c, d, 0]);
                }
            }
        }
    }

    // Test with an executable at the beginning
    let exe = u16::from(b'a');
    let sep = u16::from(b' ');
    for a in 0..0xFFu16 {
        println!("{:x}", a);
        for b in 0..0xFFu16 {
            for c in 0..0xFFu16 {
                for d in 0..0xFFu16 {
                    check(&[exe, sep, a, b, c, d, 0]);
                }
            }
        }
    }
}