4
\$\begingroup\$

To learn about Rust and parsing, I wrote a simple calculator that can solve problems involving the four basic operations (+, -, *, /) as well as exponentiation. For the sake of simplicity and performance, everything is an f64, so you get all of the usual float weirdness like 0.30000000000000004 and 1 / 0 = inf. I would love a tough code review to help me improve my skills.

Test cases

  • 2 ^ 2 ^ 3 => 256
  • (2 ^ 2) ^ 3 => 64
  • 5 ^ -1/2 (-8) => -0.8
  • +1++0.2e2--0.3E3 => 321
  • 5^-(1/2) => Error (this could be tokenized as 5 ^ (0 - (1 / 2)) but that's relatively difficult to implement)

One of my goals was to make it as efficient and robust as possible, so to put my application to the test I fed it a 26 GB file consisting of the string (2/2*2/2*2/2) repeated two billion times. It was able to process it and give the right answer in 958.32 seconds, but I'm sure there is still a lot of room for optimization, for example by eliminating unnecessary cloning.

Disclaimer

The algorithm was developed mostly by trial and error (I don't actually know anything about parsing...). I did look into the shunting yard algorithm, but it seemed to involve reading the input twice, which isn't very efficient.

use std::{error::Error, fs::File, str::Chars, time::Instant};
use memmap2::Mmap;

/// A binary arithmetic operator understood by the calculator.
///
/// The enum carries no data, so it is `Copy` (no `.clone()` needed to reuse a
/// value) and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Operator {
    /// Addition.
    Add,
    /// Subtraction.
    Subtract,
    /// Multiplication.
    Multiply,
    /// Division.
    Divide,
    /// Exponentiation.
    Power,
}

/// A single lexical unit of an arithmetic expression.
#[derive(Debug, Clone)]
enum Token {
    /// A numeric literal, already parsed to `f64`.
    Number(f64),
    /// A binary operator.
    Operator(Operator),
    /// An opening bracket.
    OpenBracket,
    /// A closing bracket.
    CloseBracket,
}

/// Lazy tokenizer state: turns the characters of an expression into `Token`s
/// one at a time through its `Iterator` implementation.
struct Tokens<'a> {
    /// The characters of the input that have not been consumed yet.
    chars: Chars<'a>,
    /// The token produced most recently, or `None` before the first token.
    /// Consulted to resolve `+`/`-` and to insert implicit multiplications.
    prev_token: Option<Token>,
    /// A token that must be produced by the next call to `next` before any
    /// more input is read (the second half of an implicit multiplication).
    immediately_return: Option<Token>,
}

impl Iterator for Tokens<'_> {
    type Item = Result<Token, String>;

    /// Produces the next token of the expression.
    ///
    /// Returns `None` once the input is exhausted, `Some(Ok(token))` for every
    /// recognised token and `Some(Err(message))` when a run of characters cannot
    /// be parsed as a number.
    ///
    /// Two conveniences are applied while lexing:
    ///
    /// * `+` and `-` are operators only directly after a number or a closing
    ///   bracket; anywhere else they are treated as the sign of a number.
    /// * An implicit multiplication is inserted for `4(2)`, `(4)(2)` and `(4)2`.
    ///
    /// Numbers may use scientific notation, including a signed exponent such as
    /// `1e-2` or `2E+1`.
    fn next(&mut self) -> Option<Self::Item> {
        // The second half of an implicit multiplication queued by the previous call
        // takes priority over reading more input.
        if let Some(queued) = self.immediately_return.take() {
            self.prev_token = Some(queued.clone());
            return Some(Ok(queued));
        }

        // Whitespace never changes `prev_token`, so this can be decided up front.
        let after_operand = matches!(
            self.prev_token,
            Some(Token::Number(_) | Token::CloseBracket)
        );

        loop {
            // Remember where the token starts so a number can be parsed straight from
            // a slice of the input instead of from a freshly allocated `String`.
            let rest = self.chars.as_str();
            let c = self.chars.next()?;
            if c.is_whitespace() {
                continue;
            }

            let token = match c {
                '*' => Token::Operator(Operator::Multiply),
                '/' => Token::Operator(Operator::Divide),
                '^' => Token::Operator(Operator::Power),
                ')' => Token::CloseBracket,
                // Insert a multiplication if necessary, e.g. 4(2) is the same as 4 * (2).
                '(' if after_operand => {
                    self.immediately_return = Some(Token::OpenBracket);
                    Token::Operator(Operator::Multiply)
                }
                '(' => Token::OpenBracket,
                // Ambiguous: a sign is an operator only when it follows an operand;
                // otherwise it falls through and becomes part of a number.
                '+' if after_operand => Token::Operator(Operator::Add),
                '-' if after_operand => Token::Operator(Operator::Subtract),
                _ => {
                    let mut len = c.len_utf8();
                    // A sign belongs to the number only directly after `e`/`E`.
                    let mut after_exponent = false;
                    while let Some(lookahead) = self.chars.clone().next() {
                        let accepted = match lookahead {
                            '0'..='9' | '.' | 'e' | 'E' => true,
                            '+' | '-' => after_exponent,
                            _ => false,
                        };
                        if !accepted {
                            break;
                        }
                        after_exponent = matches!(lookahead, 'e' | 'E');
                        len += lookahead.len_utf8();
                        self.chars.next();
                    }

                    let chunk = &rest[..len];
                    match chunk.parse::<f64>() {
                        // Insert a multiplication if necessary, e.g. (4)2 is the same as (4) * 2.
                        Ok(value) if matches!(self.prev_token, Some(Token::CloseBracket)) => {
                            self.immediately_return = Some(Token::Number(value));
                            Token::Operator(Operator::Multiply)
                        }
                        Ok(value) => Token::Number(value),
                        Err(_) => {
                            return Some(Err(format!(
                                "Encountered unexpected character(s): {chunk}"
                            )));
                        }
                    }
                }
            };

            self.prev_token = Some(token.clone());
            return Some(Ok(token));
        }
    }
}

/// Creates a lazy tokenizer over `input_str`.
///
/// No work happens here: tokens are produced one at a time as the returned
/// iterator is advanced. The iterator borrows `input_str`, which the explicit
/// `'_` in the return type makes visible at the call site.
fn tokenize(input_str: &str) -> Tokens<'_> {
    Tokens {
        chars: input_str.chars(),
        prev_token: None,
        immediately_return: None,
    }
}

/// Evaluates an arithmetic expression and returns its value.
///
/// Supported are numbers, brackets and the operators `+`, `-`, `*`, `/` and `^`.
/// `^` binds tightest and is right-associative, `*` and `/` come next and `+` and
/// `-` bind loosest; the latter two groups are left-associative. All arithmetic
/// is done on `f64`.
///
/// The input is processed in a single pass using two stacks: operands are pushed
/// onto `number_stack`, pending operators onto `operator_stack`, and both are
/// collapsed as soon as precedence allows.
///
/// # Errors
///
/// Returns a message describing the problem when the tokenizer rejects the input
/// or when the tokens do not form a valid expression (including empty input and
/// mismatched brackets).
///
/// # Panics
///
/// Only if an internal stack invariant is broken, which would be a bug here.
fn calculator(input_str: &str) -> Result<f64, String> {
    /// Applies the operator on top of `operator_stack` to the two values on top of
    /// `number_stack`, leaving the result on the stack, and returns that operator.
    /// Panics if either stack is too short.
    fn reduce_last(number_stack: &mut Vec<f64>, operator_stack: &mut Vec<Operator>) -> Operator {
        let rhs = number_stack.pop().expect("number stack holds an operand");
        let op = operator_stack
            .pop()
            .expect("operator stack holds an operator");
        let lhs = number_stack
            .last_mut()
            .expect("number stack holds a left-hand operand");
        match op {
            Operator::Add => *lhs += rhs,
            Operator::Subtract => *lhs -= rhs,
            Operator::Multiply => *lhs *= rhs,
            Operator::Divide => *lhs /= rhs,
            Operator::Power => *lhs = lhs.powf(rhs),
        }
        op
    }

    /// Keeps reducing until an addition or subtraction has been applied, i.e. until
    /// the current sum is folded into a single value.
    fn reduce_to_sum(number_stack: &mut Vec<f64>, operator_stack: &mut Vec<Operator>) {
        loop {
            let popped_op = reduce_last(number_stack, operator_stack);
            if matches!(popped_op, Operator::Add | Operator::Subtract) {
                break;
            }
        }
    }

    // Every (sub)expression is evaluated as `sentinel + <contents>`. The sentinel is
    // -0.0 rather than 0.0 because -0.0 is the exact additive identity for floats:
    // `0.0 + -0.0` is `0.0`, which would silently flip `1/(-0)` from -inf to inf.
    let mut number_stack: Vec<f64> = vec![-0.0];
    let mut operator_stack: Vec<Operator> = Vec::new();
    // Pretend the input is preceded by `+` so the first operand is added to the sentinel.
    let mut prev_token = Token::Operator(Operator::Add);
    let mut brackets_depth: usize = 0;

    // Loop over tokens, expanding and contracting number_stack and operator_stack as necessary.
    for curr_token in tokenize(input_str) {
        let curr_token = curr_token?;
        // Binary operators and closing brackets are only valid after an operand.
        let follows_operand = matches!(prev_token, Token::Number(_) | Token::CloseBracket);

        match curr_token {
            Token::Operator(Operator::Add | Operator::Subtract) => {
                if !follows_operand {
                    return Err("Invalid expression (add or subtract)".into());
                }
                // Lowest precedence: everything pending in this bracket level can be folded.
                reduce_to_sum(&mut number_stack, &mut operator_stack);
            }
            Token::Operator(Operator::Multiply | Operator::Divide) => {
                if !follows_operand {
                    return Err("Invalid expression (unexpected multiply or divide)".into());
                }
                // Fold pending `*`, `/` and `^`, but leave the enclosing sum alone.
                while !matches!(
                    operator_stack.last(),
                    Some(Operator::Add | Operator::Subtract)
                ) {
                    reduce_last(&mut number_stack, &mut operator_stack);
                }
            }
            Token::Operator(Operator::Power) => {
                // Right-associative and highest precedence: nothing can be folded yet.
                if !follows_operand {
                    return Err("Invalid expression (unexpected power)".into());
                }
            }
            Token::OpenBracket => {
                match prev_token {
                    Token::Operator(op) => operator_stack.push(op),
                    Token::OpenBracket => operator_stack.push(Operator::Add),
                    Token::Number(_) | Token::CloseBracket => {
                        return Err("Invalid expression (unexpected open bracket)".into())
                    }
                }
                // Start a fresh `sentinel + <contents>` frame for the bracket.
                number_stack.push(-0.0);
                brackets_depth += 1;
            }
            Token::CloseBracket => {
                if !follows_operand {
                    return Err("Invalid expression (unexpected close bracket)".into());
                }
                brackets_depth = brackets_depth
                    .checked_sub(1)
                    .ok_or("Invalid expression (mismatched brackets)")?;
                // Fold the bracket's contents into its sentinel.
                reduce_to_sum(&mut number_stack, &mut operator_stack);
            }
            Token::Number(value) => {
                number_stack.push(value);
                match prev_token {
                    Token::OpenBracket => operator_stack.push(Operator::Add),
                    Token::Operator(op) => operator_stack.push(op),
                    Token::Number(_) | Token::CloseBracket => {
                        return Err("Invalid expression.".into())
                    }
                }
            }
        }
        prev_token = curr_token;
    }

    if matches!(prev_token, Token::Operator(_)) || brackets_depth > 0 {
        return Err("Invalid expression.".into());
    }

    reduce_to_sum(&mut number_stack, &mut operator_stack);

    // At this point, there should only be one element in number_stack.
    Ok(number_stack[0])
}

/// Path of the file whose contents `main` evaluates as a single expression.
const FILE_NAME: &str = "long_eq2.txt";

/// Memory-maps `FILE_NAME`, evaluates its contents as one expression and prints
/// the result together with the total elapsed time.
///
/// Any failure (opening or mapping the file, invalid UTF-8, an invalid
/// expression) is returned to the caller as an error.
fn main() -> Result<(), Box<dyn Error>> {
    let timer = Instant::now();

    let input = File::open(FILE_NAME)?;
    // NOTE(review): mapping is `unsafe`; this presumably relies on the file not being
    // modified or truncated by anyone else while it is mapped — confirm.
    let mapped = unsafe { Mmap::map(&input)? };
    let expression = std::str::from_utf8(&mapped)?;

    let answer = calculator(expression)?;
    println!("Answer: {answer}"); // should be 1
    println!("Elapsed time: {:.2?}", timer.elapsed());
    Ok(())
}
\$\endgroup\$

0

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Browse other questions tagged or ask your own question.