#![feature(fs_try_exists)]
#![feature(write_all_vectored)]
#![feature(split_array)]
#![feature(file_create_new)]
#![allow(clippy::needless_return)]
#![deny(unsafe_op_in_unsafe_fn)]
#![deny(private_in_public)]
mod par_map;
use std::io::Read;
use std::os::fd::AsRawFd;
use std::os::fd::FromRawFd;
use std::io::Write;
use par_map::ParallelMap;
// "kill!" means "completely drop this value": we drop it and then shadow the name, so the value cannot be accessed again even if it is Copy.
// Ends the use of local variable `$x` for the rest of the enclosing scope.
// The value is dropped (a no-op for references and `Copy` values), then the name is
// shadowed by a value of a private unit type that is itself moved out, so any later
// mention of `$x` is a compile error.
macro_rules! kill {
($x:ident) => {
// Drop the original value; these lints would otherwise fire when `$x` is a reference or `Copy`.
#[allow(dropping_references)]
#[allow(dropping_copy_types)]
::core::mem::drop($x);
// Shadow the name with a value of a fresh, useless type.
let $x = { struct Killed; Killed };
// Move the placeholder out too, so even uses that would type-check (e.g. `let y = $x;`)
// are rejected as "use of moved value". `Killed` has no `Drop` impl, hence the allow.
#[allow(clippy::drop_non_drop)]
::core::mem::drop($x);
}
}
/// Iterator state for splitting `input` into consecutive chunks of `chunk_size` bytes.
///
/// Every produced chunk is `chunk_size` bytes long except possibly the last one.
struct Chunks {
    /// Where the bytes come from (`/dev/stdin` in the `make` subcommand).
    input: std::fs::File,
    /// Target size of each chunk; must be non-zero.
    chunk_size: usize,
    /// Total number of bytes consumed from `input` so far.
    file_offset: usize,
}
impl Iterator for Chunks {
    type Item = Vec<u8>;
    /// Reads the next chunk of up to `chunk_size` bytes from `input`.
    ///
    /// Keeps calling `read` until the buffer is full or `read` reports end of input
    /// (returns 0). A short read from a pipe therefore does not produce a short chunk in
    /// the middle of the stream; only the final chunk may be shorter than `chunk_size`.
    /// Returns `None` once no bytes at all could be read.
    ///
    /// After every successful read, the byte range just consumed is passed to
    /// `posix_fadvise(POSIX_FADV_DONTNEED)` so the kernel may evict it from the page
    /// cache. The advice is best-effort and its result is ignored (it fails e.g. on pipes).
    ///
    /// # Panics
    /// Panics if `chunk_size` is 0, or if reading fails with an error other than
    /// `ErrorKind::Interrupted` (which is retried).
    fn next(&mut self) -> Option<Vec<u8>> {
        assert_ne!(self.chunk_size, 0);
        let mut buf = vec![0u8; self.chunk_size];
        let mut filled: usize = 0;
        while filled < self.chunk_size {
            let just_read = match self.input.read(&mut buf[filled..]) {
                Ok(n) => n,
                // Interrupted by a signal before any data was transferred: retry, as
                // recommended by the `Read::read` contract.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => panic!("failed to read input: {e}"),
            };
            if just_read == 0 {
                // End of input. Do not call posix_fadvise here: a length of 0 would mean
                // "from `file_offset` to the end of the file", not "nothing".
                break;
            }
            filled += just_read;
            assert!(filled <= self.chunk_size);
            // SAFETY: posix_fadvise only receives a file descriptor and integers and does not
            // access our memory; the descriptor is owned by `self.input` and open for this call.
            let _ = unsafe {
                libc::posix_fadvise(
                    self.input.as_raw_fd(),
                    self.file_offset.try_into().unwrap(),
                    just_read.try_into().unwrap(),
                    libc::POSIX_FADV_DONTNEED,
                )
            };
            self.file_offset += just_read;
        }
        if filled == 0 {
            return None;
        }
        buf.truncate(filled);
        return Some(buf);
    }
}
// https://github.com/rayon-rs/rayon/pull/1071#discussion_r1253435955 (yield) (deadlock?)
// Slightly safer wrapper around libc::open: taking &CStr guarantees the path is NUL-terminated.
/// Thin wrapper over `libc::open` that takes the path as a `&CStr`, so the pointer handed
/// to C is always NUL-terminated.
///
/// Returns the new file descriptor, or -1 (with `errno` set) on failure, exactly as `open(2)`.
///
/// # Safety
/// Marked `unsafe` to mirror the raw libc call; the `&CStr` argument already covers the
/// NUL-termination requirement. On success the caller owns the returned descriptor and is
/// responsible for closing it exactly once (e.g. by wrapping it with `File::from_raw_fd`).
unsafe fn open_cstr(pathname: &std::ffi::CStr, flags: std::ffi::c_int, mode: libc::mode_t) -> std::ffi::c_int {
    let path_ptr = pathname.as_ptr();
    // SAFETY: `path_ptr` points into `pathname`, a NUL-terminated string borrowed for the whole call.
    let fd = unsafe { libc::open(path_ptr, flags, mode) };
    return fd;
}
/// Thin wrapper over `libc::linkat` that takes both paths as `&CStr`, so the pointers handed
/// to C are always NUL-terminated.
///
/// Returns 0 on success, or -1 (with `errno` set) on failure, exactly as `linkat(2)`.
///
/// # Safety
/// Marked `unsafe` to mirror the raw libc call; the `&CStr` arguments already cover the
/// NUL-termination requirement. `olddirfd`/`newdirfd` should be open directory descriptors
/// or `libc::AT_FDCWD`.
unsafe fn linkat_cstr(olddirfd: std::ffi::c_int, oldpath: &std::ffi::CStr, newdirfd: std::ffi::c_int, newpath: &std::ffi::CStr, flags: std::ffi::c_int) -> std::ffi::c_int {
    let (old_ptr, new_ptr) = (oldpath.as_ptr(), newpath.as_ptr());
    // SAFETY: both pointers come from `&CStr`s that stay borrowed for the duration of the call.
    let status = unsafe { libc::linkat(olddirfd, old_ptr, newdirfd, new_ptr, flags) };
    return status;
}
/// Digest algorithm used to name chunk files and to fill index files.
#[derive(Clone, Copy, Debug)]
enum HashMethod {
    /// BLAKE2b via `blake2b_simd::blake2b` (64-byte raw digest).
    Blake2b,
    /// BLAKE3 via `blake3::hash` (32-byte raw digest).
    Blake3,
}
/// Hashes `data` with `method` and returns the raw (binary, not hex-encoded) digest bytes.
fn calc_hash(method: HashMethod, data: &[u8]) -> Vec<u8> {
    let digest: Vec<u8> = match method {
        HashMethod::Blake2b => {
            let hash = blake2b_simd::blake2b(data);
            hash.as_bytes().to_vec()
        }
        HashMethod::Blake3 => {
            let hash = blake3::hash(data);
            hash.as_bytes().to_vec()
        }
    };
    digest
}
/// Parses a `--digest=<method>` command-line option into a [`HashMethod`].
///
/// # Panics
/// Panics with a message naming the rejected option if `option` is neither
/// `--digest=blake2b` nor `--digest=blake3`.
fn option_to_hash(option: &str) -> HashMethod {
    return match option {
        "--digest=blake2b" => HashMethod::Blake2b,
        "--digest=blake3" => HashMethod::Blake3,
        _ => panic!("expected --digest=blake2b or --digest=blake3, got {option:?}"),
    };
}
fn main() {
const CHUNKS: &str = "chunks";
const INDEX: &str = "index";
let args = std::env::args().collect::<Vec<String>>();
if args[1] == "make" {
let [_, _, ref o_chunk_size, ref o_store, ref o_zstd, ref o_id, ref o_digest] = *args else { panic!(); };
let chunk_size: usize = o_chunk_size.strip_prefix("--chunk-size=").unwrap().parse().unwrap();
let store = o_store.strip_prefix("--store=").unwrap();
let zstd = o_zstd.strip_prefix("--zstd=").unwrap().parse().unwrap();
let id = o_id.strip_prefix("--id=").unwrap();
let hash_method = option_to_hash(o_digest);
let zeros = vec![0u8; chunk_size];
let zeros_hash = calc_hash(hash_method, &zeros);
let mut index_file = {
let fd = unsafe { open_cstr(&std::ffi::CString::new(format!("{store}/{INDEX}/")).unwrap(), libc::O_TMPFILE | libc::O_WRONLY, 0o666) };
assert_ne!(fd, -1);
unsafe { std::io::BufWriter::new(std::fs::File::from_raw_fd(fd)) }
};
Chunks { input: std::fs::File::open("/dev/stdin").unwrap(), chunk_size, file_offset: 0 }.par_map_for_each(|chunk| {
let raw_hash = if chunk == zeros {
zeros_hash.clone()
} else {
calc_hash(hash_method, &chunk)
};
let hash_str = hex::encode(&raw_hash);
if !std::fs::try_exists(format!("{store}/{CHUNKS}/{hash_str}")).unwrap() {
let fd = unsafe { open_cstr(&std::ffi::CString::new(format!("{store}/{CHUNKS}/")).unwrap(), libc::O_TMPFILE | libc::O_WRONLY, 0o666) };
assert_ne!(fd, -1);
let mut chunk_file = unsafe { std::fs::File::from_raw_fd(fd) };
kill!(fd);
[()] = [chunk_file.write_all_vectored(&mut [
std::io::IoSlice::new(&u64::try_from(chunk.len()).unwrap().to_le_bytes()),
std::io::IoSlice::new(&zstd::bulk::compress(&chunk, zstd).unwrap()),
]).unwrap()];
if unsafe { linkat_cstr(libc::AT_FDCWD, &std::ffi::CString::new(format!("/proc/self/fd/{}", chunk_file.as_raw_fd())).unwrap(), libc::AT_FDCWD, &std::ffi::CString::new(format!("{store}/{CHUNKS}/{hash_str}")).unwrap(), libc::AT_SYMLINK_FOLLOW) } == -1 {
// BUG2022:INFORMATIONAL: https://github.com/rayon-rs/rayon/issues/1069
let err = unsafe { *libc::__errno_location() };
assert_eq!(err, libc::EEXIST);
}
}
return raw_hash;
}, |hash| {
[()] = [index_file.write_all(&hash).unwrap()];
});
let index_file = index_file.into_inner().unwrap();
assert_eq!(unsafe { linkat_cstr(libc::AT_FDCWD, &std::ffi::CString::new(format!("/proc/self/fd/{}", index_file.as_raw_fd())).unwrap(), libc::AT_FDCWD, &std::ffi::CString::new(format!("{store}/{INDEX}/{id}")).unwrap(), libc::AT_SYMLINK_FOLLOW) }, 0);
} else if args[1] == "extract" {
let [_, _, ref o_store, ref o_id, ref o_to, ref o_digest, ref o_check_extracted] = *args else { panic!(); };
let store = o_store.strip_prefix("--store=").unwrap();
let id = o_id.strip_prefix("--id=").unwrap();
let to = o_to.strip_prefix("--to=").unwrap();
let hash_method = option_to_hash(o_digest);
let check_extracted = match &**o_check_extracted {
"--check-extracted=true" => true,
"--check-extracted=false" => false,
_ => panic!(),
};
let mut hash_vec = vec![];
std::fs::File::open(format!("{store}/{INDEX}/{id}")).unwrap().read_to_end(&mut hash_vec).unwrap();
let hash_len = match hash_method {
HashMethod::Blake2b => 64,
HashMethod::Blake3 => 32,
};
let chunks = hash_vec.chunks_exact(hash_len);
assert_eq!(chunks.remainder().len(), 0);
let mut to = std::fs::File::create_new(to).unwrap();
chunks.par_map_for_each(|raw_hash| {
let hash = hex::encode(raw_hash);
let mut data = vec![];
std::fs::File::open(format!("{store}/{CHUNKS}/{hash}")).unwrap().read_to_end(&mut data).unwrap();
let data = data.split_array_ref();
let decompressed_size = usize::try_from(u64::from_le_bytes(*data.0)).unwrap();
let chunk = zstd::bulk::decompress(data.1, decompressed_size).unwrap();
assert_eq!(chunk.len(), decompressed_size);
if check_extracted {
assert_eq!(calc_hash(hash_method, &chunk), raw_hash);
}
kill!(hash);
kill!(data);
kill!(decompressed_size);
return chunk;
}, |chunk| {
[()] = [to.write_all(&chunk).unwrap()];
});
} else {
panic!("Usage");
}
}
use rayon::ScopeFifo;
use std::sync::Arc;
use std::collections::VecDeque;
use std::sync::mpsc::Receiver;
use std::sync::mpsc::sync_channel;
/// Abstraction over "somewhere a task can be spawned": a borrowed rayon `ScopeFifo`
/// (tasks may borrow data that lives for `'scope`) or the global pool via `GlobalScope`
/// (tasks must be `'static`).
///
/// Each task is handed a reference to the scope it runs in, so it can spawn more work there.
trait LocalOrGlobalScope<'scope>: 'scope {
    /// Schedules `task` to run on this scope.
    fn spawn_in_scope<Task>(&self, task: Task)
    where
        Task: for<'scoperef> FnOnce(&'scoperef Self) + Send + 'scope;
}
impl<'scope> LocalOrGlobalScope<'scope> for ScopeFifo<'scope> {
fn spawn_in_scope<Task>(&self, task: Task)
where
Task: for<'scoperef> FnOnce(&'scoperef Self),
Task: Send + 'scope
{
self.spawn_fifo(task);
}
}
/// Marker "scope" that spawns tasks onto rayon's global thread pool; tasks must be `'static`.
pub struct GlobalScope;
impl LocalOrGlobalScope<'static> for GlobalScope {
    /// Spawns `task` with `rayon::spawn_fifo`, handing it a `&GlobalScope`.
    fn spawn_in_scope<Task>(&self, task: Task)
    where
        Task: for<'scoperef> FnOnce(&'scoperef Self) + Send + 'static,
    {
        let job = move || task(&GlobalScope);
        rayon::spawn_fifo(job);
    }
}
/// Order-preserving parallel-map iterator returned by the `par_map*` methods.
///
/// Inputs are pulled lazily from `iter`. Each one becomes a task on `scope`, and that
/// task's result channel is queued in `chans`, so results come out in input order.
pub struct ParallelMapIter<'scoperef, 'scope, SomeScope, InputIter, OutputItem>
where
    InputIter: Iterator,
{
    /// One single-result receiver per spawned-but-unconsumed task, oldest first.
    chans: VecDeque<Receiver<OutputItem>>,
    /// Remaining inputs; fused so the source is not polled again after it returns `None`.
    iter: std::iter::Fuse<InputIter>,
    /// Scope that tasks are spawned on.
    scope: &'scoperef SomeScope,
    /// The mapping function, shared by every task.
    // `dyn` because return_position_impl_trait_in_trait is not yet stable, so the
    // concrete closure type cannot be named here.
    op: Arc<dyn for<'scoperef2> Fn(&'scoperef2 SomeScope, InputIter::Item) -> OutputItem + Sync + Send + 'scope>,
}
/// Spawns `op(scope, input)` on `scope` and appends the receiving end of a fresh
/// one-slot channel to the back of `chans`; the task sends its result into that channel.
///
/// Callers pop receivers from the front of `chans`, which keeps outputs in input order even
/// though tasks may finish out of order.
///
/// If the receiver has already been dropped by the time the task finishes (the consumer
/// stopped early, e.g. after `take(n)`), the result is silently discarded. Panicking there
/// would tear down the worker; for `rayon::spawn_fifo` an unhandled panic aborts the
/// process by default.
fn push_task<'scoperef, 'scope, SomeScope, InputItem, OutputItem>(
    input: InputItem,
    scope: &'scoperef SomeScope,
    op: Arc<dyn for<'scoperef2> Fn(&'scoperef2 SomeScope, InputItem) -> OutputItem + Sync + Send + 'scope>,
    chans: &mut VecDeque<Receiver<OutputItem>>,
)
where
    InputItem: Send + 'scope,
    OutputItem: Send + 'scope,
    SomeScope: LocalOrGlobalScope<'scope>,
{
    // Capacity 1: exactly one value is ever sent, so `send` never blocks.
    let (send, recv) = sync_channel(1);
    scope.spawn_in_scope(|scope| {
        // An error only means the receiver is gone and nobody wants this result.
        let _ = send.send(op(scope, input));
        // The explicit drops make this non-`move` closure capture `send` and `op` by value.
        drop(send);
        drop(op);
    });
    chans.push_back(recv);
}
/// Builds a [`ParallelMapIter`] over `iter`, spawning the first `capacity - 1` tasks on
/// `scope` right away.
///
/// The iterator's `next` spawns one more task (if input remains) before waiting, so at most
/// `capacity` tasks are spawned-but-unconsumed at any time.
///
/// # Panics
/// Panics if `capacity` is 0.
fn low_level_par_map<'scoperef, 'scope, SomeScope, InputIter, OutputItem, Op>(iter: InputIter, scope: &'scoperef SomeScope, capacity: usize, op: Op) -> ParallelMapIter<'scoperef, 'scope, SomeScope, InputIter, OutputItem>
where
    Op: for<'scoperef2> Fn(&'scoperef2 SomeScope, InputIter::Item) -> OutputItem,
    Op: Sync + Send + 'scope,
    InputIter::Item: Send + 'scope,
    OutputItem: Send + 'scope,
    SomeScope: LocalOrGlobalScope<'scope>,
    InputIter: Iterator,
{
    assert!(capacity >= 1);
    let shared_op: Arc<dyn for<'scoperef2> Fn(&'scoperef2 SomeScope, InputIter::Item) -> OutputItem + Sync + Send + 'scope> = Arc::new(op);
    let mut inputs = iter.fuse();
    let mut pending = VecDeque::new();
    for input in inputs.by_ref().take(capacity - 1) {
        push_task(input, scope, Arc::clone(&shared_op), &mut pending);
    }
    ParallelMapIter { chans: pending, iter: inputs, scope, op: shared_op }
}
/// Extension trait adding order-preserving parallel `map` adapters to every iterator.
///
/// Each input item becomes one rayon task. Results are yielded (or handed to the
/// `for_each` callback) in input order. At most `capacity` tasks are spawned-but-unconsumed
/// at any time, which also bounds how far the input iterator is read ahead of the consumer.
pub trait ParallelMap: Iterator + Sized {
/// Maps `op` over `self` on the rayon `scope`, keeping at most `capacity` tasks
/// spawned-but-unconsumed.
///
/// `capacity - 1` tasks are spawned immediately. Each `next` on the returned iterator
/// spawns one more (if input remains) and then blocks until the oldest result is ready.
/// `op` also receives the scope, so it can spawn nested work.
///
/// # Panics
/// Panics if `capacity` is 0.
fn par_map_with_scope_and_capacity<'scoperef, 'scope, OutputItem, Op>(self, scope: &'scoperef ScopeFifo<'scope>, capacity: usize, op: Op) -> ParallelMapIter<'scoperef, 'scope, ScopeFifo<'scope>, Self, OutputItem>
where
Op: for<'scoperef2> Fn(&'scoperef2 ScopeFifo<'scope>, Self::Item) -> OutputItem,
Op: Sync + Send + 'scope,
Self::Item: Send + 'scope,
OutputItem: Send + 'scope,
{
low_level_par_map(self, scope, capacity, op)
}
/// Like [`ParallelMap::par_map_with_scope_and_capacity`], with `capacity` set to twice the
/// thread count reported from inside `scope`.
///
/// NOTE(review): the thread count is learned by spawning a probe task on `scope` and
/// blocking on its reply, so this call blocks until a worker picks the probe up. If the
/// caller is itself the only worker of that pool, this looks like it could deadlock —
/// confirm before using it from inside a single-threaded pool.
fn par_map_with_scope<'scoperef, 'scope, OutputItem, Op>(self, scope: &'scoperef ScopeFifo<'scope>, op: Op) -> ParallelMapIter<'scoperef, 'scope, ScopeFifo<'scope>, Self, OutputItem>
where
Op: for<'scoperef2> Fn(&'scoperef2 ScopeFifo<'scope>, Self::Item) -> OutputItem,
Op: Sync + Send + 'scope,
Self::Item: Send + 'scope,
OutputItem: Send + 'scope,
{
// We could just call rayon::current_num_threads here. Unfortunately, that does not work if the scope was created using rayon::ThreadPool::in_place_scope_fifo: the calling thread is then not a worker of that pool, so the count would come from a different pool. So we spawn a task on the scope merely to learn its number of threads
let (send, recv) = sync_channel(1);
// `drop(send)` inside the closure makes it take ownership of `send`.
scope.spawn_fifo(|_|{
send.send(rayon::current_num_threads()).unwrap();
drop(send);
});
let capacity = recv.recv().unwrap() * 2;
drop(recv);
self.par_map_with_scope_and_capacity(scope, capacity, op)
}
/// Maps `op` over `self` on rayon's global pool, keeping at most `capacity` tasks
/// spawned-but-unconsumed. The tasks are not tied to any scope, so `op`, the input items
/// and the outputs must all be `'static`.
///
/// # Panics
/// Panics if `capacity` is 0.
fn par_map_with_capacity<OutputItem, Op>(self, capacity: usize, op: Op) -> ParallelMapIter<'static, 'static, GlobalScope, Self, OutputItem>
where
Op: Fn(Self::Item) -> OutputItem,
Op: Sync + Send + 'static,
Self::Item: Send + 'static,
OutputItem: Send + 'static,
{
low_level_par_map(self, &GlobalScope, capacity, move |_, input|op(input))
}
/// Like [`ParallelMap::par_map_with_capacity`], with
/// `capacity = 2 * rayon::current_num_threads()`.
fn par_map<OutputItem, Op>(self, op: Op) -> ParallelMapIter<'static, 'static, GlobalScope, Self, OutputItem>
where
Op: Fn(Self::Item) -> OutputItem,
Op: Sync + Send + 'static,
Self::Item: Send + 'static,
OutputItem: Send + 'static,
{
self.par_map_with_capacity(rayon::current_num_threads() * 2, op)
}
/// Maps `map_op` over `self` in parallel and passes the outputs, in input order, to
/// `for_each_op`, which is called from the closure given to `rayon::in_place_scope_fifo`.
///
/// Because everything runs inside that scope, `map_op` and the items may borrow local data
/// (no `'static` bound). Every result is consumed before the scope closure returns.
///
/// # Panics
/// Panics if `capacity` is 0.
fn par_map_for_each_with_capacity<MapOp, ForEachOp, OutputItem>(self, capacity: usize, map_op: MapOp, for_each_op: ForEachOp)
where
MapOp: Fn(Self::Item) -> OutputItem,
MapOp: Sync + Send,
ForEachOp: FnMut(OutputItem),
Self::Item: Send,
OutputItem: Send,
{
rayon::in_place_scope_fifo(|s|{
self.par_map_with_scope_and_capacity(s, capacity, move |_, input|map_op(input)).for_each(for_each_op);
});
}
/// Like [`ParallelMap::par_map_for_each_with_capacity`], with
/// `capacity = 2 * rayon::current_num_threads()`.
fn par_map_for_each<MapOp, ForEachOp, OutputItem>(self, map_op: MapOp, for_each_op: ForEachOp)
where
MapOp: Fn(Self::Item) -> OutputItem,
MapOp: Sync + Send,
ForEachOp: FnMut(OutputItem),
Self::Item: Send,
OutputItem: Send,
{
self.par_map_for_each_with_capacity(rayon::current_num_threads() * 2, map_op, for_each_op);
}
}
/// Every iterator gets the `par_map*` adapters; all methods use the trait's default bodies.
impl<I> ParallelMap for I
where
    I: Iterator,
{
}
impl<'scoperef, 'scope, SomeScope, InputIter, OutputItem> Iterator for ParallelMapIter<'scoperef, 'scope, SomeScope, InputIter, OutputItem>
where
    InputIter: Iterator,
    InputIter::Item: Send + 'scope,
    OutputItem: Send + 'scope,
    SomeScope: LocalOrGlobalScope<'scope>,
{
    type Item = OutputItem;
    /// Spawns a task for the next input (if any), then blocks until the oldest outstanding
    /// task has produced its result and returns it.
    ///
    /// Refilling before waiting keeps the pipeline full while we block.
    ///
    /// # Panics
    /// Panics if the oldest task dropped its sender without sending a result (presumably
    /// because the mapping function panicked).
    fn next(&mut self) -> Option<OutputItem> {
        if let Some(input) = self.iter.next() {
            push_task(input, self.scope, Arc::clone(&self.op), &mut self.chans);
        }
        self.chans.pop_front().map(|oldest| oldest.recv().unwrap())
    }
    /// Every queued channel yields exactly one item, and every remaining input yields one
    /// more. The hint is therefore as accurate as the input iterator's own hint.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending = self.chans.len();
        let (lower, upper) = self.iter.size_hint();
        (lower.saturating_add(pending), upper.and_then(|upper| upper.checked_add(pending)))
    }
}