\$\begingroup\$

I tried to implement a limited version of Memcached. This is the first time I have used multi-threading / async in Rust, so I had a lot of trouble implementing the part where we establish connections and read the buffers. I also used tokio because every similar project that I saw was using it, but I am not sure I understand the benefits. I also had a lot of problems testing the code with the built-in Rust tests, so I decided to implement the tests in Python (since the tests required speaking to my app as an external process, it felt easier to implement them in a scripting language). I am very unsure about how I read these frames and about everything surrounding connections / multi-threading / async, and I would like to request some advice. Any advice would be very welcome. My first goal with this project is to learn as much as possible about these systems and how they work. (I also tried pytest but found this approach simpler, because I did not need the tests to be concurrent and it was easier to start up / clean up, probably because I am not familiar enough with the framework.)

-- Main.rs --

use const_format::concatcp;
use decoder::{Command, GetCommand, StorageCommand};
use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
    time::{Duration, SystemTime},
};
use tokio::net::{TcpListener, TcpStream};
mod decoder;

const PORT: &'static str = "8012";
const ADDRESS: &'static str = concatcp!("127.0.0.1:", PORT);

#[derive(Debug)]
struct DbEntry {
    value: String,
    flags: u16,
    byte_count: usize,
    valid_until: Option<SystemTime>,
}

struct Database {
    map: Mutex<HashMap<String, DbEntry>>,
}

async fn start_tcp_server() {
    let listener = TcpListener::bind(ADDRESS).await.unwrap();

    let db = Arc::new(Database {
        map: Mutex::new(HashMap::new()),
    });

    loop {
        let db_ref = Arc::clone(&db);
        let (socket, _addr) = listener.accept().await.expect("couldn't get client");
        tokio::spawn(async move {
            handle_connection(socket, db_ref).await;
        });
    }
}

async fn handle_connection(socket: TcpStream, db: Arc<Database>) {
    let mut decoder = decoder::Decoder::new(socket);
    loop {
        let command = decoder.decode().await;
        let response = match command {
            Ok(command) => {
                println!("Command recieved was {:?}", command);
                execute_command(command, &db)
            }
            Err(error) => match error {
                decoder::DecodeError::ConnectionClosed => {
                    println!("connection closed!");
                    return;
                }
                decoder::DecodeError::ParseError(parse_error) => match parse_error {
                    decoder::ParseError::InvalidFormat(client_error) => {
                        format!("CLIENT_ERROR {}\r\n", client_error)
                    }
                    decoder::ParseError::UnknownCommand(_message) => "ERROR\r\n".to_owned(),
                },
            },
        };
        println!("Created response: {}", response);
        decoder.send(response).await;
    }
}

fn execute_command(command: Command, db: &Database) -> String {
    let response = match command {
        Command::Set(set_command) => handle_set(set_command, db),
        Command::Get(get_command) => handle_get(get_command, db),
        Command::Add(add_command) => handle_add(add_command, db),
        Command::Replace(replace_command) => handle_replace(replace_command, db),
        Command::Prepend(prepend_command) => handle_prepend(prepend_command, db),
        Command::Append(append_command) => handle_append(append_command, db),
    };

    return response;
}

fn handle_append(command: StorageCommand, db: &Database) -> String {
    let mut map = db.map.lock().unwrap();

    if is_key_valid(&command.key, &map) {
        let value = map.get_mut(&command.key).expect("Already checked if valid");
        value.byte_count += command.byte_count;
        value.value.push_str(command.payload.as_str());
        return "STORED\r\n".to_string();
    }
    map.remove(&command.key);

    "NOT_STORED\r\n".to_string()
}

fn handle_prepend(command: StorageCommand, db: &Database) -> String {
    let mut map = db.map.lock().unwrap();

    if is_key_valid(&command.key, &map) {
        let value = map.get_mut(&command.key).expect("Already checked if valid");
        value.byte_count += command.byte_count;
        value.value = command.payload + (value.value).as_str();
        return "STORED\r\n".to_string();
    }
    map.remove(&command.key);

    "NOT_STORED\r\n".to_string()
}
fn handle_replace(command: StorageCommand, db: &Database) -> String {
    let mut map = db.map.lock().unwrap();

    if is_key_valid(&command.key, &map) {
        insert_entry(
            command.key,
            command.payload,
            command.flags,
            command.byte_count,
            command.exptime,
            &mut map,
        );
        return "STORED\r\n".to_string();
    }

    map.remove(&command.key);

    "NOT_STORED\r\n".to_string()
}

fn handle_add(command: StorageCommand, db: &Database) -> String {
    let mut map = db.map.lock().unwrap();

    if is_key_valid(&command.key, &map) {
        return "NOT_STORED\r\n".to_string();
    }

    insert_entry(
        command.key,
        command.payload,
        command.flags,
        command.byte_count,
        command.exptime,
        &mut map,
    );
    "STORED\r\n".to_string()
}

fn handle_set(command: StorageCommand, db: &Database) -> String {
    let mut map = db.map.lock().unwrap();

    insert_entry(
        command.key,
        command.payload,
        command.flags,
        command.byte_count,
        command.exptime,
        &mut map,
    );
    "STORED\r\n".to_string()
}

fn handle_get(command: GetCommand, db: &Database) -> String {
    let mut map = db.map.lock().unwrap();
    match map.get(&command.key) {
        Some(entry) => {
            if let Some(valid_until) = entry.valid_until {
                if valid_until < SystemTime::now() {
                    // Key expired
                    map.remove(&command.key);
                    return "END\r\n".to_string();
                }
            }

            format!(
                "VALUE {} {} {}\r\n{}\r\nEND\r\n",
                command.key, entry.flags, entry.byte_count, entry.value
            )
        }
        None => "END\r\n".to_string(),
    }
}

fn is_key_valid(key: &str, map: &HashMap<String, DbEntry>) -> bool {
    // The keys are lazlily evaluated and removed
    if let Some(entry) = map.get(key) {
        if let Some(valid_until) = entry.valid_until {
            if valid_until < SystemTime::now() {
                return false;
            }
        }
        return true;
    }
    false
}

fn calculate_valid_until(exptime: i128) -> Option<SystemTime> {
    if exptime > 0 {
        Some(SystemTime::now() + Duration::new(exptime.try_into().unwrap(), 0))
    } else {
        None
    }
}

fn insert_entry(
    key: String,
    payload: String,
    flags: u16,
    byte_count: usize,
    exptime: i128,
    map: &mut HashMap<String, DbEntry>,
) {
    if exptime < 0 {
        return;
    }

    let valid_until = calculate_valid_until(exptime);
    let entry = DbEntry {
        value: payload,
        flags,
        byte_count,
        valid_until,
    };
    map.insert(key, entry);
}

#[tokio::main]
async fn main() {
    start_tcp_server().await;
}

-- decoder.rs --

use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    net::TcpStream,
};

#[derive(Debug)]
pub enum Command {
    Get(GetCommand),
    Set(StorageCommand),
    Add(StorageCommand),
    Replace(StorageCommand),
    Prepend(StorageCommand),
    Append(StorageCommand),
}

#[derive(Debug)]
pub struct StorageCommand {
    pub key: String,
    pub flags: u16,
    pub exptime: i128,
    pub byte_count: usize,
    pub no_reply: bool, // I did not implement any functionality for this
    pub payload: String,
}

#[derive(Debug)]
pub struct GetCommand {
    pub key: String,
}

#[derive(Debug)]
pub enum ParseError {
    InvalidFormat(String),
    UnknownCommand(String),
}
#[derive(Debug)]
pub struct ConnectionClosedError {}

pub enum DecodeError {
    ConnectionClosed,
    ParseError(ParseError),
}
struct ConnectionBuffer {
    connection: TcpStream,
    buffer: String,
}

impl ConnectionBuffer {
    async fn read_header(&mut self) -> Result<String, ConnectionClosedError> {
        let data = self.read_until_delimeter().await;
        data
    }

    async fn read_until_delimeter(&mut self) -> Result<String, ConnectionClosedError> {
        while !self.buffer.contains("\r\n") {
            let mut data = vec![0; 1024];

            let n = self
                .connection
                .read(&mut data)
                .await
                .expect("failed to read data from socket");

            // Socket closed
            if n == 0 {
                return Err(ConnectionClosedError {});
            }

            self.buffer.push_str(
                core::str::from_utf8(&data[0..n])
                    .expect("These bytes cannot be converted to UTF-8"),
            );
        }

        let (header, rest) = split_once(&self.buffer, "\r\n");
        self.buffer = rest;
        return Ok(header);
    }

    async fn read_payload(&mut self, byte_count: usize) -> Result<String, ConnectionClosedError> {
        let byte_count = byte_count + 2; //The extra \r\n is not part of the byte count inside the header field

        if self.buffer.len() < byte_count {
            let mut data = vec![0; 1024];

            let n = self
                .connection
                .read(&mut data)
                .await
                .expect("failed to read data from socket");

            // Socket closed
            if n == 0 {
                return Err(ConnectionClosedError {});
            }

            self.buffer.push_str(
                core::str::from_utf8(&data[0..n])
                    .expect("These bytes cannot be converted to UTF-8"),
            );
        }
        let buffer = self.buffer.clone();
        let (payload, rest) = buffer.split_at(byte_count as usize);
        self.buffer = rest.to_string();
        Ok(payload.to_string())
    }
}

pub struct Decoder {
    connection: ConnectionBuffer,
}

impl Decoder {
    pub fn new(connection: TcpStream) -> Self {
        return Self {
            connection: ConnectionBuffer {
                connection: connection,
                buffer: String::new(),
            },
        };
    }

    pub async fn decode(&mut self) -> Result<Command, DecodeError> {
        let header = self.connection.read_header().await;
        let header = match header {
            Ok(contents) => contents,
            Err(_conection_closed) => return Err(DecodeError::ConnectionClosed),
        };
        let parse_result = parse_header(header);
        match parse_result {
            Ok(command) => match command {
                Command::Set(set_command) => {
                    return self
                        .parse_storage_command_payload(set_command)
                        .await
                        .map(|x| Command::Set(x))
                }
                Command::Get(_) => return Ok(command),
                Command::Add(add_command) => {
                    return self
                        .parse_storage_command_payload(add_command)
                        .await
                        .map(|x| Command::Add(x))
                }
                Command::Replace(replace_command) => {
                    return self
                        .parse_storage_command_payload(replace_command)
                        .await
                        .map(|x| Command::Replace(x))
                }
                Command::Prepend(prepend_command) => {
                    return self
                        .parse_storage_command_payload(prepend_command)
                        .await
                        .map(|x| Command::Prepend(x))
                }
                Command::Append(append_command) => {
                    return self
                        .parse_storage_command_payload(append_command)
                        .await
                        .map(|x| Command::Append(x))
                }
            },
            Err(parse_error) => Err(DecodeError::ParseError(parse_error)),
        }
    }

    pub async fn send(&mut self, response: String) -> () {
        let bytes = response.as_bytes();
        self.connection
            .connection
            .write_all(bytes)
            .await
            .expect("failed to write data to socket");
    }

    async fn parse_storage_command_payload(
        &mut self,
        command: StorageCommand,
    ) -> Result<StorageCommand, DecodeError> {
        let payload = self
            .connection
            .read_payload(command.byte_count)
            .await
            .map_err(|_| DecodeError::ConnectionClosed)?;

        let payload = parse_payload(payload, command.byte_count)
            .map_err(|err| DecodeError::ParseError(err))?;

        return Ok(StorageCommand {
            payload: payload,
            ..command
        });
    }
}

fn parse_payload(mut payload: String, byte_count: usize) -> Result<String, ParseError> {
    if payload.ends_with("\r\n") {
        payload.truncate(byte_count as usize);
        return Ok(payload);
    }

    return Err(ParseError::InvalidFormat(
        "Expected \r\n at the end of string".to_string(),
    ));
}

fn parse_header(header: String) -> Result<Command, ParseError> {
    // ? I think we cannot detect missing \r\n in header because we must read untill we see \r\n
    // ? It is up to the client to not mess this up.

    let keywords: Vec<_> = header.split_whitespace().collect();
    let command = keywords
        .get(0)
        .ok_or(ParseError::InvalidFormat("Missing command".to_string()))?;
    let key = keywords
        .get(1)
        .ok_or(ParseError::InvalidFormat("Missing key".to_string()))?
        .to_string();

    match command.to_lowercase().as_str() {
        "get" => Ok(Command::Get(GetCommand { key })),
        "set" | "add" | "replace" | "append" | "prepend" => {
            let storage_command = parse_storage_command(&keywords, key)?;
            match *command {
                "set" => Ok(Command::Set(storage_command)),
                "add" => Ok(Command::Add(storage_command)),
                "replace" => Ok(Command::Replace(storage_command)),
                "append" => Ok(Command::Append(storage_command)),
                "prepend" => Ok(Command::Prepend(storage_command)),
                _ => unreachable!(),
            }
        }
        _ => Err(ParseError::UnknownCommand(command.to_string())),
    }
}

fn parse_storage_command(keywords: &[&str], key: String) -> Result<StorageCommand, ParseError> {
    let flags = parse_field::<u16>(keywords, 2, "flags")?;
    let exptime = parse_field::<i128>(keywords, 3, "exptime")?;
    let byte_count = parse_field::<usize>(keywords, 4, "byte count")?;
    let no_reply = keywords.get(5).is_some(); // If there's a 6th keyword, assume "no_reply"

    Ok(StorageCommand {
        key,
        flags,
        exptime,
        byte_count,
        no_reply,
        payload: "".to_owned(),
    })
}

fn parse_field<T: std::str::FromStr>(
    keywords: &[&str],
    index: usize,
    field_name: &str,
) -> Result<T, ParseError> {
    keywords
        .get(index)
        .ok_or(ParseError::InvalidFormat(format!(
            "{} is missing",
            field_name
        )))?
        .parse::<T>()
        .map_err(|_parse_number_errror| {
            ParseError::InvalidFormat(format!("Expected a valid {} value", field_name))
        })
}

fn split_once(in_string: &str, pat: &str) -> (String, String) {
    let mut splitter = in_string.splitn(2, pat);
    let first = splitter.next().unwrap().to_string();
    let second = splitter.next().unwrap().to_string();
    (first, second)
}

-- test_app.py --

import socket
import subprocess
import threading
import time
import random

HOST = "127.0.0.1"
PORT = 8012


def start_process():
    # Start the application using cargo run
    process = subprocess.Popen(
        ["cargo", "run"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    time.sleep(3)

    return process  # Test will run after this


def tcp_connection() -> socket.socket:
    """Establish and return a TCP connection."""
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.connect((HOST, PORT))
    return s


def send_command(socket: socket.socket, command: str, data: str = "") -> str:
    """Send a command to the server and return the response."""
    socket.sendall(command.encode("utf-8"))
    if data:
        socket.sendall(data.encode("utf-8"))
    return socket.recv(1024).decode("utf-8")


def assert_response(response: str, expected: str, message: str):
    """Assert that a response matches the expected value."""
    assert (
        response == expected
    ), f"{message}: Expected '{expected}', but got '{response}'"


def test_commands(process):
    try:
        socket = tcp_connection()

        # Test: Set command
        assert_response(
            send_command(socket, "set test 0 0 4\r\n", "1234\r\n"),
            "STORED\r\n",
            "Set command failed",
        )

        # Test: Get command
        assert_response(
            send_command(socket, "get test\r\n"),
            "VALUE test 0 4\r\n1234\r\nEND\r\n",
            "Get command failed for existing key",
        )

        # Test: Get command for non existing
        assert_response(
            send_command(socket, "get non-existing-key\r\n"),
            "END\r\n",
            "Get command failed for non existing key",
        )

        # Test: Unknown command
        assert_response(
            send_command(socket, "unknowncmd test\r\n"),
            "ERROR\r\n",
            "Response failed for unknown command",
        )

        # Test: Add command (new key)
        assert_response(
            send_command(socket, "add newkey 0 0 4\r\n", "data\r\n"),
            "STORED\r\n",
            "Add command failed when adding non existing key",
        )

        # Test: Add command (existing key)
        assert_response(
            send_command(socket, "add newkey 0 0 4\r\n", "data\r\n"),
            "NOT_STORED\r\n",
            "Add command failed when adding existing key",
        )

        # Test: Replace command (existing key)
        assert_response(
            send_command(socket, "replace test 0 0 4\r\n", "john\r\n"),
            "STORED\r\n",
            "Replace command failed when replacing existing key",
        )

        # Verify replaced value using get
        assert_response(
            send_command(socket, "get test\r\n"),
            "VALUE test 0 4\r\njohn\r\nEND\r\n",
            "Getting wrong key after replace",
        )

        # Test: Replace command (non-existing key)
        assert_response(
            send_command(socket, "replace test2 0 0 4\r\n", "data\r\n"),
            "NOT_STORED\r\n",
            "Replace command failed when replacing non-existing key",
        )

        # Test: Append command (existing key)
        assert_response(
            send_command(socket, "append test 0 0 4\r\n", "more\r\n"),
            "STORED\r\n",
            "Append command failed for existing key",
        )
        assert_response(
            send_command(socket, "get test\r\n"),
            "VALUE test 0 8\r\njohnmore\r\nEND\r\n",
            "Get command failed after append",
        )

        # Test: Prepend and Append command (existing key)
        assert_response(
            send_command(socket, "set middle 0 0 4\r\n", "data\r\n"),
            "STORED\r\n",
            "set command failed",
        )

        assert_response(
            send_command(socket, "prepend middle 0 0 3\r\n", "pre\r\n"),
            "STORED\r\n",
            "Prepend command failed for existing key",
        )
        assert_response(
            send_command(socket, "get middle\r\n"),
            "VALUE middle 0 7\r\npredata\r\nEND\r\n",
            "Get command failed after prepend",
        )

        assert_response(
            send_command(socket, "append middle 0 0 3\r\n", "end\r\n"),
            "STORED\r\n",
            "append command failed for existing key",
        )
        assert_response(
            send_command(socket, "get middle\r\n"),
            "VALUE middle 0 10\r\npredataend\r\nEND\r\n",
            "Get command failed after append",
        )

        # Test: Append command (non-existing key)
        assert_response(
            send_command(socket, "append foo 0 0 4\r\n", "test\r\n"),
            "NOT_STORED\r\n",
            "Append command failed for non-existing key",
        )

        # Test: Prepend command (non-existing key)
        assert_response(
            send_command(socket, "prepend foo 0 0 4\r\n", "test\r\n"),
            "NOT_STORED\r\n",
            "Prepend command failed for non-existing key",
        )

        print("---- Test Commands Passed ----")

    finally:
        socket.close()


def test_concurrect(process: subprocess.Popen[bytes]):
    threads = []

    def perform_set():
        try:
            socket = tcp_connection()

            for _ in range(10):
                random_val = random.randint(0, 9)

                assert_response(
                    send_command(socket, "set value 0 0 1\r\n", f"{random_val}\r\n"),
                    "STORED\r\n",
                    "Failed store in concurrent test case",
                )
        finally:
            socket.close()

    for i in range(10):
        t = threading.Thread(target=perform_set)
        threads.append(t)
        t.start()

    # Wait for all threads to complete
    for t in threads:
        t.join()

    print("---- Test Concurrent Passed ----")


def test_expiration(process):
    try:
        socket = tcp_connection()

        # Set a key with a 4-second expiration time
        assert_response(
            send_command(socket, "set tempkey 0 4 5\r\n", "hello\r\n"),
            "STORED\r\n",
            "Set command failed",
        )

        # Retrieve the key before expiration
        assert_response(
            send_command(socket, "get tempkey\r\n"),
            "VALUE tempkey 0 5\r\nhello\r\nEND\r\n",
            "Cannot find temp key before it expires",
        )

        # Wait for expiration
        time.sleep(4.1)

        # Attempt to retrieve the key after expiration
        assert_response(
            send_command(socket, "get tempkey\r\n"),
            "END\r\n",
            "Temp key is not deleted after expiration time",
        )

        print("---- Test Expiration Passed ----")

    finally:
        socket.close()


try:
    process = start_process()
    print("Process Started, starting tests ...")
    test_commands(process)
    test_concurrect(process)
    test_expiration(process)
    print("--- All Tests Passed ---")
finally:
    process.terminate()
    process.wait()
\$\endgroup\$

1 Answer

\$\begingroup\$

racy test

Sleeping for a fixed time after launching the server is rather suspicious. It's a common test smell.

    process = subprocess.Popen(
        ...
    time.sleep(3)

Either it will wait too short a time, leading to wrong results, or too long, needlessly making the tests take longer to run.

Prefer to issue a blocking read() on the child's stdout pipe. If that requires the Rust target code to print a message like "Ready!", or its version number, then so be it.

Also, rewrite the # comment to be a proper docstring:

    """Starts the application, using cargo run."""

Perhaps clarify that we're starting the target application, rather than any testing applications.
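
The blocking read suggested above could look roughly like the sketch below. It assumes the Rust server is changed to print a line containing "Ready!" once the listener is bound, which it does not do today. The helper name wait_until_ready and the marker text are hypothetical, and the child process here is a Python stand-in for cargo run so that the snippet runs on its own.

```python
import subprocess
import sys


def wait_until_ready(process: subprocess.Popen, marker: str = "Ready!") -> None:
    """Block until the child prints a line containing `marker` on stdout."""
    assert process.stdout is not None, "start the child with stdout=subprocess.PIPE"
    for line in process.stdout:  # blocks on each line, so no sleep() is needed
        if marker in line:
            return
    # The loop only ends at EOF, which means the child exited early.
    raise RuntimeError(f"process exited before printing {marker!r}")


# Stand-in for `cargo run` (assumption): prints the marker, then stays alive.
CHILD = "import time; print('Ready!', flush=True); time.sleep(30)"


def start_process() -> subprocess.Popen:
    """Start the target application and return once it reports readiness."""
    process = subprocess.Popen(
        [sys.executable, "-c", CHILD], stdout=subprocess.PIPE, text=True
    )
    try:
        wait_until_ready(process)
    except BaseException:
        # Do not leave a half-started child behind if readiness never comes.
        process.kill()
        process.wait()
        raise
    return process
```

One thing to keep in mind with this approach: the pipe stays open after the marker is seen, so a server that keeps printing to stdout will eventually need something to drain that pipe.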

docstrings

I appreciate that you're trying to consistently use docstrings, I really do. It's a good impulse to follow, a far better one than adopting an approach of never writing them. But consider eliding them when a well-written function signature makes the docstring obvious or redundant.

For example, renaming to def establish_tcp_connection() would render its docstring superfluous -- it's already clear we shall return the new connection. And then consider tweaking the verb to be create_ or just get_.

In contrast, send_command's docstring is helpful -- it points out that we await the server's response.

nit: You might find it convenient to from socket import socket
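
As an illustration of the renaming and the import nit, here is a minimal sketch. The name create_tcp_connection and the host / port parameters are hypothetical, and using the standard library's create_connection() in place of the manual socket() plus connect() pair is a substitution of mine, not something the original code does.

```python
from socket import create_connection, socket

HOST = "127.0.0.1"
PORT = 8012


def create_tcp_connection(host: str = HOST, port: int = PORT) -> socket:
    # The name and the return annotation already say everything the old
    # docstring said, so the docstring is left out here.
    return create_connection((host, port))
```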

lint

I appreciate that you're using mypy (or maybe pyright?) to keep your signatures consistent. Good job!

If you use the mypy --strict switch, it will point out things like "please append -> None to the assert_response() signature". And it will ask you to consistently annotate the type of the process parameters.

nit: Consider using a verb like validate_ or verify_, which doesn't follow quite so literally from the implementation that you happened to use. After all, in future a maintainer might choose to phrase it if response != expected: raise ...

Nice diagnostic! Just what someone will need when they start investigating a bug report or failed test.
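
Putting the lint suggestions together, a possible shape for the helper is sketched below. The name verify_response is hypothetical, and the switch to !r formatting (which shows \r\n escaped rather than as literal line breaks) is my own tweak to the diagnostic, not part of the original.

```python
def verify_response(response: str, expected: str, message: str) -> None:
    """Raise AssertionError with a diagnostic if the response differs."""
    # An explicit raise keeps working under `python -O`, which strips
    # plain `assert` statements.
    if response != expected:
        raise AssertionError(
            f"{message}: Expected {expected!r}, but got {response!r}"
        )
```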

test_ prefix

In the python ecosystem this prefix has special meaning, to humans and to software components.

def test_commands(process): ...

def test_concurrect(process: subprocess.Popen[bytes]): ...

def test_expiration(process): ...

We have no from unittest import TestCase, so at first reading I took these functions as being intended for the pytest runner. (Or nose, or the various other runners which adhere to this convention.)

Then the Review Context and the final try: clause explain that we won't be executing under a test framework. To reflect that, please rename these so they don't start with test_.

Or perhaps you'd like to revisit your pytest learning journey. It's not that hard, I promise!

nit, typo: Use the conventional spelling of "concurrent".

nice helper

This is a pretty good design and it works well:

        # Test: Set command
        assert_response(
            send_command(socket, "set test 0 0 4\r\n", "1234\r\n"),
            "STORED\r\n",
            "Set command failed",
        )

We see a great many tests where # description matches that 3rd message parameter. Consider promoting it to be the 1st parameter, so you can DRY up and delete all those redundant # comments.

But maybe we don't need to annotate with descriptions at all? Consider using def test_set_command within a TestCase, and then the combination of self.assertEqual() and the function name offer plenty of explanation for what went wrong. Plus we needn't worry about keeping the message in sync when we alter what's being tested.

Or just assert x == y and let pytest reveal the details if it falls apart.
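
To make the TestCase idea concrete, here is a minimal sketch. FakeMemcached is a hypothetical in-process stand-in that only answers "set" and unknown commands, included so the snippet runs without the Rust binary; a real suite would start the cargo-built server in setUpClass instead. The class and method names are illustrative.

```python
import socket
import socketserver
import threading
import unittest


class FakeMemcached(socketserver.StreamRequestHandler):
    """Hypothetical stand-in for the Rust server so the sketch runs alone."""

    def handle(self) -> None:
        while header := self.rfile.readline():
            if header.startswith(b"set "):
                self.rfile.readline()  # consume the payload line
                self.wfile.write(b"STORED\r\n")
            else:
                self.wfile.write(b"ERROR\r\n")


class CommandTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        # Port 0 lets the OS pick a free port for the stand-in server.
        cls.server = socketserver.ThreadingTCPServer(("127.0.0.1", 0), FakeMemcached)
        cls.server.daemon_threads = True
        threading.Thread(target=cls.server.serve_forever, daemon=True).start()

    @classmethod
    def tearDownClass(cls) -> None:
        cls.server.shutdown()
        cls.server.server_close()

    def setUp(self) -> None:
        # Each test gets a fresh connection that is closed automatically.
        self.conn = socket.create_connection(self.server.server_address)
        self.addCleanup(self.conn.close)

    def send(self, text: str) -> str:
        self.conn.sendall(text.encode("utf-8"))
        return self.conn.recv(1024).decode("utf-8")

    def test_set_command(self) -> None:
        self.assertEqual(self.send("set test 0 0 4\r\n1234\r\n"), "STORED\r\n")

    def test_unknown_command(self) -> None:
        self.assertEqual(self.send("unknowncmd test\r\n"), "ERROR\r\n")
```

The test method names now carry the descriptions, and assertEqual() reports both values on failure, so no message parameter is needed.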

nit: test_commands() is commendably simple, but maybe it's starting to get a little on the long side?

logging

In your concurrent test, I would like to see detailed BEGIN and END timestamps, please, to convince us the executions truly are overlapping. Also, 10 impresses me as maybe one or two orders of magnitude smaller than desired, if you want to stress the SUT and try to provoke racing issues.
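
One way to collect such timestamps is sketched below. The helper names run_concurrently and overlapped are hypothetical, the Barrier that releases all workers together is an addition of mine rather than something requested above, and the work callable stands in for the body of perform_set.

```python
import logging
import threading
import time

logging.basicConfig(
    format="%(asctime)s.%(msecs)03d %(threadName)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)


def run_concurrently(work, n_threads: int) -> list:
    """Run work() in n_threads threads; return each thread's (begin, end) times."""
    intervals = []
    lock = threading.Lock()
    start_gate = threading.Barrier(n_threads)

    def worker() -> None:
        start_gate.wait()  # release all threads together to maximise overlap
        begin = time.monotonic()
        logging.info("BEGIN")
        work()
        end = time.monotonic()
        logging.info("END")
        with lock:
            intervals.append((begin, end))

    threads = [threading.Thread(target=worker) for _ in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return intervals


def overlapped(intervals) -> bool:
    """True if every interval shares a common instant with all the others."""
    latest_begin = max(begin for begin, _ in intervals)
    earliest_end = min(end for _, end in intervals)
    return latest_begin < earliest_end
```

A call such as overlapped(run_concurrently(perform_set, 100)) would then confirm, rather than assume, that the workers were all in flight at the same time, and the log lines show the exact ordering.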

parallel tests

The four-second delay in test_expiration() makes perfect sense. But things like that will slow down your edit-test cycle as additional features and tests accumulate.

Consider using unittest-parallel to run other tests while that long sleep() is happening. Or pytest -n 8 (with the pytest-xdist plugin installed), which spreads the tests across eight worker processes.

__main__ guard

It is customary to protect "non-def" code with

if __name__ == "__main__":

so a module can be safely imported without side effects.

Not a bad habit to get into, though frankly it's unlikely to make much difference here.
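
A minimal sketch of the guard applied to this script follows. The names run_all and main are hypothetical, and the command is a Python stand-in for ["cargo", "run"] so the snippet runs anywhere. Starting the process before the try block is a choice of mine, so that the cleanup never refers to a process that failed to start.

```python
import subprocess
import sys


def run_all(command: list[str]) -> int:
    """Start the target, run the checks, and always clean up afterwards."""
    process = subprocess.Popen(command)  # outside try: nothing to clean up if this raises
    try:
        # The existing scenario functions would be called here.
        return 0
    finally:
        process.terminate()
        process.wait()


def main() -> int:
    # Stand-in for ["cargo", "run"] (assumption) so the sketch is runnable.
    return run_all([sys.executable, "-c", "import time; time.sleep(30)"])


if __name__ == "__main__":
    main()
```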

Overall, these look like valuable and effective automated tests.

\$\endgroup\$
  • Note that if you generate PDF/HTML/whatever documentation from docstrings, adding redundant docstrings is actually useful; without it the automated documentation generators don't know what human readable text to write, and then it looks like the function is undocumented, which looks unprofessional and/or can be confusing (for example, it might then look like it's not part of the actual API).
    – G. Sliepen
    Commented Nov 17 at 11:24
