Jelajahi Sumber

Added basic standard input handling

AvariceLHubris 3 minggu lalu
induk
melakukan
152991dc65
1 mengubah file dengan 97 tambahan dan 71 penghapusan
  1. 97 71
      src/main.rs

+ 97 - 71
src/main.rs

@@ -1,12 +1,11 @@
+use std::io::{self, Read, Write};
 use std::path::Path;
-use std::{io::Write, path::PathBuf};
+use std::path::PathBuf;
 
 use anyhow::{Context, anyhow};
 use clap::Parser;
 use huffman::{cli, hufftree, storage};
 
-// TODO: Add STDIN as input
-
 fn main() -> Result<(), anyhow::Error> {
     let args = cli::Args::parse();
 
@@ -14,100 +13,127 @@ fn main() -> Result<(), anyhow::Error> {
     let outputf = args.output_file;
     let mode = args.mode;
 
-    if !inputf.exists() {
-        return Err(anyhow!("Input file did not exist."));
-    }
+    let is_stdin = inputf == Path::new("-");
 
-    let mode = match mode {
-        Some(mode) => mode,
-        None => determine_mode(&inputf, outputf.as_ref()),
+    // Read all input into memory upfront so we know its size and can inspect content.
+    let input_bytes: Vec<u8> = if is_stdin {
+        let mut buf = Vec::new();
+        io::stdin()
+            .read_to_end(&mut buf)
+            .context("Could not read from stdin.")?;
+        buf
+    } else {
+        if !inputf.exists() {
+            return Err(anyhow!("Input file did not exist."));
+        }
+        std::fs::read(&inputf).context("Could not read input file.")?
     };
+    let in_size = input_bytes.len() as u64;
 
-    let outputf = match outputf {
-        Some(name) => name,
-        None => match mode {
-            cli::Mode::X => {
-                if let Some(ext) = inputf.extension()
-                    && ext.eq("z")
-                {
-                    inputf.with_extension("")
+    let mode = match mode {
+        Some(m) => m,
+        None => {
+            if is_stdin {
+                // No filename to inspect — infer from content validity.
+                if std::str::from_utf8(&input_bytes).is_ok() {
+                    cli::Mode::C
                 } else {
-                    inputf.with_extension("unhuffed")
+                    cli::Mode::X
                 }
+            } else {
+                determine_mode(&inputf, outputf.as_ref())
             }
-            cli::Mode::C => match inputf.extension() {
-                // Check if we have an extension
-                Some(ext) => {
-                    // If so we retrieve it and then add to it a ".z"
-                    let ext = ext
-                        .to_str()
-                        .ok_or(anyhow!("Input file path was not valid unicode."))?;
-                    let new_extension = ext.to_string() + ".z";
-
-                    inputf.with_extension(new_extension)
+        }
+    };
+
+    // None means write to stdout.
+    let output_path: Option<PathBuf> = if is_stdin && outputf.is_none() {
+        None
+    } else {
+        Some(match outputf {
+            Some(p) => p,
+            None => match mode {
+                cli::Mode::X => {
+                    if let Some(ext) = inputf.extension()
+                        && ext.eq("z")
+                    {
+                        inputf.with_extension("")
+                    } else {
+                        inputf.with_extension("unhuffed")
+                    }
                 }
-                // Otherwise just use the native syntax for add '.z'
-                None => inputf.with_extension(".z"),
+                cli::Mode::C => match inputf.extension() {
+                    Some(ext) => {
+                        let ext = ext
+                            .to_str()
+                            .ok_or(anyhow!("Input file path was not valid unicode."))?;
+                        inputf.with_extension(ext.to_string() + ".z")
+                    }
+                    None => inputf.with_extension("z"),
+                },
             },
-        },
+        })
     };
 
-    let working_directory = std::path::Path::new(".");
-    let inputf = working_directory.join(inputf);
-    let in_size = inputf
-        .metadata()
-        .context("Could not get the input file's metadata")?
-        .len();
+    // When the output is stdout, status messages go to stderr to avoid corrupting binary output.
+    macro_rules! status {
+        ($($arg:tt)*) => {
+            if output_path.is_none() {
+                eprintln!($($arg)*);
+            } else {
+                println!($($arg)*);
+            }
+        };
+    }
 
-    let outputf = working_directory.join(outputf);
-    let mut outputf = std::fs::File::create(outputf)?;
+    let mut writer: Box<dyn Write> = match output_path {
+        Some(ref p) => {
+            Box::new(std::fs::File::create(p).context("Could not create output file.")?)
+        }
+        None => Box::new(io::stdout()),
+    };
+
+    status!("Read: {} bytes.", in_size);
 
     match mode {
         cli::Mode::X => {
-            let inputf = std::fs::read(inputf)?;
-            println!("Read: {} bytes.", in_size);
-            println!("Decoding text...");
-            let decoded_text = huffman::storage::read_tree_and_text(&mut &inputf[..]);
-            println!("Decoded!");
+            status!("Decoding text...");
+            let decoded_text = huffman::storage::read_tree_and_text(&mut &input_bytes[..]);
+            status!("Decoded!");
 
-            outputf
+            writer
                 .write_all(decoded_text.as_bytes())
-                .context("Could not write decoded text to output file.")?;
-            let out_size = outputf
-                .metadata()
-                .context("Could not get the input file's metadata")?
-                .len();
-            println!("Stored: {} bytes.", out_size);
+                .context("Could not write decoded text to output.")?;
+
+            let out_size = decoded_text.len() as u64;
+            status!("Stored: {} bytes.", out_size);
             let (compressed, original) = (in_size, out_size);
-            println!(
-                "Compression Ratio: {:.2}.",
-                compressed as f64 / original as f64
-            );
+            status!("Compression Ratio: {:.2}.", compressed as f64 / original as f64);
         }
 
         cli::Mode::C => {
-            let inputf =
-                std::fs::read_to_string(inputf).context("Could not read input file to string.")?;
-            println!("Read: {} bytes.", in_size);
-            println!("Encoding text...");
-            let char_f = huffman::hufftree::base::get_char_frequencies(&inputf);
+            let input_text =
+                String::from_utf8(input_bytes).context("Input is not valid UTF-8.")?;
+            status!("Encoding text...");
+            let char_f = huffman::hufftree::base::get_char_frequencies(&input_text);
 
             let base_tree = huffman::hufftree::base::Hufftree::new(char_f);
             let canonical_tree = hufftree::canonical::CanonicalHufftree::from_tree(base_tree);
 
-            storage::store_tree_and_text(canonical_tree, &mut outputf, &inputf)
+            // Buffer encoded output so we can report its size before writing.
+            let mut out_buf: Vec<u8> = Vec::new();
+            storage::store_tree_and_text(canonical_tree, &mut out_buf, &input_text)
                 .expect("Could not store the tree and text.");
-            println!("Encoded!");
-            let out_size = outputf
-                .metadata()
-                .context("Could not get the input file's metadata")?
-                .len();
-            println!("Stored: {} bytes.", out_size);
+            let out_size = out_buf.len() as u64;
+
+            writer
+                .write_all(&out_buf)
+                .context("Could not write encoded data to output.")?;
+
+            status!("Encoded!");
+            status!("Stored: {} bytes.", out_size);
             let (compressed, original) = (out_size, in_size);
-            println!(
-                "Compression Ratio: {:.2}.",
-                compressed as f64 / original as f64
-            );
+            status!("Compression Ratio: {:.2}.", compressed as f64 / original as f64);
         }
     }