I tried to implement a limited version of Memcached. This is the first time I have used multi-threading / async in Rust, so I had a lot of trouble implementing the part where we establish connections and read the buffers. I also used Tokio because every similar project I saw was using it, but I am not sure I understand the benefits. I also had a lot of problems testing the code with Rust's built-in tests, so I decided to implement the tests in Python (since the tests need to talk to my app as an external process, it felt easier to write them in a scripting language). I am very unsure about how I read these frames and about everything surrounding connections / multi-threading / async, and I would like to ask for some advice. Any advice would be very welcome. My first goal with this project is to learn as much as possible about these systems and how they work. (I also tried pytest, but found this approach simpler because I did not need the tests to be concurrent and it was easier to start up / clean up — probably because I am not familiar enough with the framework.)
-- Main.rs --
use const_format::concatcp;
use decoder::{Command, GetCommand, StorageCommand};
use std::{
collections::HashMap,
sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
use tokio::net::{TcpListener, TcpStream};
mod decoder;
/// TCP port the server listens on.
const PORT: &str = "8012";
/// Socket address the server binds to (loopback interface only).
const ADDRESS: &str = concatcp!("127.0.0.1:", PORT);
/// One stored item.
#[derive(Debug)]
struct DbEntry {
    /// The stored data block.
    value: String,
    /// Opaque client-supplied flags, echoed back in `get` replies.
    flags: u16,
    /// Length of `value` in bytes; `append`/`prepend` add to it.
    byte_count: usize,
    /// Instant after which the entry counts as expired; `None` = never expires.
    /// Expired entries are not purged eagerly; handlers remove them when they
    /// encounter them.
    valid_until: Option<SystemTime>,
}
/// The key-value store shared by all connections.
struct Database {
    /// All entries behind one `std::sync::Mutex`. The handlers that lock it are
    /// plain (non-async) functions, so the guard is never held across an `.await`.
    map: Mutex<HashMap<String, DbEntry>>,
}
/// Binds a listener on `ADDRESS` and serves clients until the process exits.
///
/// All connections share one `Database`. Each accepted socket is handled on
/// its own Tokio task, so one slow client does not hold up the others.
///
/// # Panics
/// Panics if the listening socket cannot be bound (e.g. the port is taken).
async fn start_tcp_server() {
    let listener = TcpListener::bind(ADDRESS).await.unwrap();
    let db = Arc::new(Database {
        map: Mutex::new(HashMap::new()),
    });
    loop {
        // A failed accept concerns only that one pending connection. Panicking
        // here would unwind out of `main` and take the whole server down, so
        // log the error and keep serving.
        // NOTE(review): a persistent error such as running out of file
        // descriptors makes this loop spin; a short back-off
        // (e.g. `tokio::time::sleep`, needs tokio's `time` feature) would help.
        let (socket, _addr) = match listener.accept().await {
            Ok(connection) => connection,
            Err(err) => {
                eprintln!("couldn't get client: {}", err);
                continue;
            }
        };
        let db_ref = Arc::clone(&db);
        tokio::spawn(async move {
            handle_connection(socket, db_ref).await;
        });
    }
}
/// Serves one client: repeatedly decodes a request, executes it against the
/// shared store and writes the reply, until the client disconnects.
///
/// Decoding problems are reported to the client instead of closing the
/// connection: malformed input gets `CLIENT_ERROR <reason>`, and an unknown
/// command name gets `ERROR`.
async fn handle_connection(socket: TcpStream, db: Arc<Database>) {
    let mut decoder = decoder::Decoder::new(socket);
    loop {
        let reply = match decoder.decode().await {
            Ok(request) => {
                println!("Command recieved was {:?}", request);
                execute_command(request, &db)
            }
            Err(decoder::DecodeError::ConnectionClosed) => {
                println!("connection closed!");
                return;
            }
            Err(decoder::DecodeError::ParseError(decoder::ParseError::InvalidFormat(
                reason,
            ))) => format!("CLIENT_ERROR {}\r\n", reason),
            Err(decoder::DecodeError::ParseError(decoder::ParseError::UnknownCommand(_))) => {
                "ERROR\r\n".to_owned()
            }
        };
        println!("Created response: {}", reply);
        decoder.send(reply).await;
    }
}
/// Dispatches a decoded request to its handler and returns the protocol reply
/// to send back to the client.
fn execute_command(command: Command, db: &Database) -> String {
    match command {
        Command::Get(request) => handle_get(request, db),
        Command::Set(request) => handle_set(request, db),
        Command::Add(request) => handle_add(request, db),
        Command::Replace(request) => handle_replace(request, db),
        Command::Prepend(request) => handle_prepend(request, db),
        Command::Append(request) => handle_append(request, db),
    }
}
/// Handles `append`: adds the payload after the value of an existing, unexpired key.
///
/// Replies `STORED` on success. Replies `NOT_STORED` when the key is missing or
/// expired; an expired entry is removed in that case. Flags and expiry of the
/// existing entry are left untouched.
fn handle_append(command: StorageCommand, db: &Database) -> String {
    let mut guard = db.map.lock().unwrap();
    if !is_key_valid(&command.key, &guard) {
        // Missing or expired: drop any stale entry and refuse.
        guard.remove(&command.key);
        return "NOT_STORED\r\n".to_string();
    }
    let entry = guard.get_mut(&command.key).expect("Already checked if valid");
    entry.value.push_str(&command.payload);
    entry.byte_count += command.byte_count;
    "STORED\r\n".to_string()
}
/// Handles `prepend`: puts the payload in front of the value of an existing,
/// unexpired key.
///
/// Replies `STORED` on success. Replies `NOT_STORED` when the key is missing or
/// expired; an expired entry is removed in that case. Flags and expiry of the
/// existing entry are left untouched.
fn handle_prepend(command: StorageCommand, db: &Database) -> String {
    let mut guard = db.map.lock().unwrap();
    if !is_key_valid(&command.key, &guard) {
        // Missing or expired: drop any stale entry and refuse.
        guard.remove(&command.key);
        return "NOT_STORED\r\n".to_string();
    }
    let entry = guard.get_mut(&command.key).expect("Already checked if valid");
    entry.byte_count += command.byte_count;
    entry.value.insert_str(0, &command.payload);
    "STORED\r\n".to_string()
}
fn handle_replace(command: StorageCommand, db: &Database) -> String {
let mut map = db.map.lock().unwrap();
if is_key_valid(&command.key, &map) {
insert_entry(
command.key,
command.payload,
command.flags,
command.byte_count,
command.exptime,
&mut map,
);
return "STORED\r\n".to_string();
}
map.remove(&command.key);
"NOT_STORED\r\n".to_string()
}
/// Handles `add`: stores the entry only if the key is absent or expired.
///
/// Replies `NOT_STORED` when a live entry already exists, `STORED` otherwise.
fn handle_add(command: StorageCommand, db: &Database) -> String {
    let mut guard = db.map.lock().unwrap();
    let already_present = is_key_valid(&command.key, &guard);
    if already_present {
        "NOT_STORED\r\n".to_string()
    } else {
        insert_entry(
            command.key,
            command.payload,
            command.flags,
            command.byte_count,
            command.exptime,
            &mut guard,
        );
        "STORED\r\n".to_string()
    }
}
fn handle_set(command: StorageCommand, db: &Database) -> String {
let mut map = db.map.lock().unwrap();
insert_entry(
command.key,
command.payload,
command.flags,
command.byte_count,
command.exptime,
&mut map,
);
"STORED\r\n".to_string()
}
/// Handles `get <key>`.
///
/// Replies `VALUE <key> <flags> <bytes>\r\n<data>\r\nEND\r\n` for a live entry.
/// Replies a bare `END\r\n` when the key is missing or expired; an expired
/// entry found here is removed on the spot.
fn handle_get(command: GetCommand, db: &Database) -> String {
    let mut guard = db.map.lock().unwrap();
    let has_expired = match guard.get(&command.key) {
        None => return "END\r\n".to_string(),
        Some(entry) => entry
            .valid_until
            .map_or(false, |deadline| deadline < SystemTime::now()),
    };
    if has_expired {
        // Key expired
        guard.remove(&command.key);
        return "END\r\n".to_string();
    }
    // The lock has been held throughout, so the entry seen above still exists.
    let entry = &guard[&command.key];
    format!(
        "VALUE {} {} {}\r\n{}\r\nEND\r\n",
        command.key, entry.flags, entry.byte_count, entry.value
    )
}
fn is_key_valid(key: &str, map: &HashMap<String, DbEntry>) -> bool {
// The keys are lazlily evaluated and removed
if let Some(entry) = map.get(key) {
if let Some(valid_until) = entry.valid_until {
if valid_until < SystemTime::now() {
return false;
}
}
return true;
}
false
}
/// Converts a client-supplied memcached `exptime` into an absolute expiry instant.
///
/// - `exptime <= 0` returns `None` (no expiry instant).
/// - `1..=2_592_000` (up to 30 days) is an offset in seconds from now.
/// - Larger values are absolute Unix timestamps in seconds, as the memcached
///   protocol specifies.
///
/// Values too large to represent as a `SystemTime` yield `None` (never
/// expires). Previously such client input panicked the connection task.
fn calculate_valid_until(exptime: i128) -> Option<SystemTime> {
    // memcached: anything above 30 days is interpreted as an absolute Unix time.
    const MAX_RELATIVE_EXPTIME: i128 = 60 * 60 * 24 * 30;
    if exptime <= 0 {
        return None;
    }
    let seconds = Duration::from_secs(u64::try_from(exptime).ok()?);
    if exptime > MAX_RELATIVE_EXPTIME {
        SystemTime::UNIX_EPOCH.checked_add(seconds)
    } else {
        SystemTime::now().checked_add(seconds)
    }
}
/// Stores (or overwrites) `key` with the given payload and metadata.
///
/// `exptime` follows memcached semantics: `0` never expires, positive values
/// are converted by `calculate_valid_until`, and a negative value means the
/// item is expired immediately. For a negative value, any existing entry for
/// `key` is removed and nothing is inserted. Previously the old value stayed
/// readable even though the caller replied `STORED`.
fn insert_entry(
    key: String,
    payload: String,
    flags: u16,
    byte_count: usize,
    exptime: i128,
    map: &mut HashMap<String, DbEntry>,
) {
    if exptime < 0 {
        map.remove(&key);
        return;
    }
    let entry = DbEntry {
        value: payload,
        flags,
        byte_count,
        valid_until: calculate_valid_until(exptime),
    };
    map.insert(key, entry);
}
/// Entry point: runs the TCP server on the Tokio runtime; the server loop never returns.
#[tokio::main]
async fn main() {
    start_tcp_server().await
}
--decoder.rs--
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
/// A fully decoded client request, ready to be executed against the store.
#[derive(Debug)]
pub enum Command {
    /// `get <key>`
    Get(GetCommand),
    /// `set`: store unconditionally.
    Set(StorageCommand),
    /// `add`: store only if the key is not already present.
    Add(StorageCommand),
    /// `replace`: store only if the key is already present.
    Replace(StorageCommand),
    /// `prepend`: put the data in front of an existing value.
    Prepend(StorageCommand),
    /// `append`: put the data after an existing value.
    Append(StorageCommand),
}

/// Header arguments plus data block shared by all storage commands:
/// `<cmd> <key> <flags> <exptime> <bytes> [noreply]\r\n<data>\r\n`.
#[derive(Debug)]
pub struct StorageCommand {
    pub key: String,
    /// Opaque client flags.
    pub flags: u16,
    /// Expiration time exactly as sent by the client.
    pub exptime: i128,
    /// Length of `payload` in bytes, excluding the trailing `\r\n`.
    pub byte_count: usize,
    /// Whether the client asked for no reply. Parsed, but not acted upon yet.
    pub no_reply: bool,
    /// The data block, without its trailing `\r\n`; empty until it has been read.
    pub payload: String,
}

/// Arguments of a `get` request.
#[derive(Debug)]
pub struct GetCommand {
    pub key: String,
}

/// Why a request could not be parsed.
#[derive(Debug)]
pub enum ParseError {
    /// The request is malformed; the text is sent back in `CLIENT_ERROR <text>`.
    InvalidFormat(String),
    /// The command name is not recognised; carries the name as received.
    UnknownCommand(String),
}

/// The connection can no longer be read from.
#[derive(Debug)]
pub struct ConnectionClosedError {}

/// Everything that can go wrong while decoding one request.
#[derive(Debug)]
pub enum DecodeError {
    /// No further request can be read; the session should end.
    ConnectionClosed,
    /// The request was received but is invalid; the session can continue.
    ParseError(ParseError),
}
/// A client socket plus the text received from it that has not been consumed
/// yet. The leftover text matters because one read can contain the end of one
/// request and the start of the next (pipelined requests).
struct ConnectionBuffer {
    connection: TcpStream,
    buffer: String,
}

impl ConnectionBuffer {
    /// Reads the next command line: everything up to (not including) `\r\n`.
    async fn read_header(&mut self) -> Result<String, ConnectionClosedError> {
        self.read_until_delimeter().await
    }

    /// Removes and returns the text before the next `\r\n` (the delimiter is
    /// consumed too), reading from the socket until one is available.
    async fn read_until_delimeter(&mut self) -> Result<String, ConnectionClosedError> {
        while !self.buffer.contains("\r\n") {
            self.fill_buffer().await?;
        }
        let (header, rest) = split_once(&self.buffer, "\r\n");
        self.buffer = rest;
        Ok(header)
    }

    /// Removes and returns a data block of `byte_count` bytes followed by its
    /// `\r\n` terminator. The returned string still contains those two
    /// trailing bytes; validating them is left to the caller.
    ///
    /// # Errors
    /// Returns `ConnectionClosedError` in three cases:
    /// - the connection ends before the whole block arrived;
    /// - the declared length overflows `usize`;
    /// - the declared length ends inside a multi-byte UTF-8 character.
    ///
    /// In the last two cases the stream cannot be re-synchronised.
    async fn read_payload(&mut self, byte_count: usize) -> Result<String, ConnectionClosedError> {
        // The extra \r\n is not part of the byte count inside the header field.
        let byte_count = byte_count.checked_add(2).ok_or(ConnectionClosedError {})?;
        // A value larger than one socket read arrives in several chunks, so keep
        // reading until the whole block is buffered before splitting it off.
        while self.buffer.len() < byte_count {
            self.fill_buffer().await?;
        }
        if !self.buffer.is_char_boundary(byte_count) {
            return Err(ConnectionClosedError {});
        }
        // Split in place instead of cloning the whole buffer.
        let rest = self.buffer.split_off(byte_count);
        Ok(std::mem::replace(&mut self.buffer, rest))
    }

    /// Reads from the socket and appends the received text to `self.buffer`.
    ///
    /// If the received bytes end in the middle of a multi-byte UTF-8
    /// character, it keeps reading until that character is complete.
    /// A character split across two reads is therefore reassembled instead of
    /// causing a panic.
    ///
    /// # Errors
    /// Returns `ConnectionClosedError` when:
    /// - the peer closed the connection;
    /// - the read itself failed (e.g. a reset connection);
    /// - the peer sent bytes that are not valid UTF-8, which the `String`
    ///   buffer cannot hold.
    async fn fill_buffer(&mut self) -> Result<(), ConnectionClosedError> {
        let mut chunk = [0u8; 1024];
        // Received bytes that end in an incomplete UTF-8 sequence.
        let mut pending: Vec<u8> = Vec::new();
        loop {
            let n = self
                .connection
                .read(&mut chunk)
                .await
                .map_err(|_| ConnectionClosedError {})?;
            // Socket closed
            if n == 0 {
                return Err(ConnectionClosedError {});
            }
            pending.extend_from_slice(&chunk[..n]);
            let complete = match std::str::from_utf8(&pending) {
                Ok(_) => pending.len(),
                // `error_len() == None`: valid so far, but the last character is
                // cut off; the next read will deliver the rest of it.
                Err(err) if err.error_len().is_none() => err.valid_up_to(),
                Err(_) => return Err(ConnectionClosedError {}),
            };
            let text = std::str::from_utf8(&pending[..complete])
                .expect("bytes before valid_up_to() are valid UTF-8");
            self.buffer.push_str(text);
            pending.drain(..complete);
            if pending.is_empty() {
                return Ok(());
            }
        }
    }
}
/// Turns the byte stream of one client connection into `Command`s and writes
/// replies back to the same connection.
pub struct Decoder {
    connection: ConnectionBuffer,
}

impl Decoder {
    /// Wraps an accepted client socket.
    pub fn new(connection: TcpStream) -> Self {
        Self {
            connection: ConnectionBuffer {
                connection,
                buffer: String::new(),
            },
        }
    }

    /// Reads and decodes the next complete request.
    ///
    /// For storage commands this also reads the data block that follows the
    /// header line.
    ///
    /// # Errors
    /// - `DecodeError::ConnectionClosed` when no more data can be read.
    /// - `DecodeError::ParseError` when the header line or the data block is
    ///   malformed.
    pub async fn decode(&mut self) -> Result<Command, DecodeError> {
        let header = self
            .connection
            .read_header()
            .await
            .map_err(|_| DecodeError::ConnectionClosed)?;
        let command = parse_header(header).map_err(DecodeError::ParseError)?;
        let decoded = match command {
            Command::Get(get_command) => Command::Get(get_command),
            Command::Set(storage) => Command::Set(self.parse_storage_command_payload(storage).await?),
            Command::Add(storage) => Command::Add(self.parse_storage_command_payload(storage).await?),
            Command::Replace(storage) => {
                Command::Replace(self.parse_storage_command_payload(storage).await?)
            }
            Command::Prepend(storage) => {
                Command::Prepend(self.parse_storage_command_payload(storage).await?)
            }
            Command::Append(storage) => {
                Command::Append(self.parse_storage_command_payload(storage).await?)
            }
        };
        Ok(decoded)
    }

    /// Writes `response` to the client.
    ///
    /// A failed write (typically because the client went away) is logged to
    /// stderr instead of panicking the connection task. The caller's next
    /// `decode` is then expected to report the connection as closed.
    pub async fn send(&mut self, response: String) {
        if let Err(err) = self
            .connection
            .connection
            .write_all(response.as_bytes())
            .await
        {
            eprintln!("failed to write data to socket: {}", err);
        }
    }

    /// Reads the data block announced by a storage command's header and
    /// returns the command with `payload` filled in.
    async fn parse_storage_command_payload(
        &mut self,
        command: StorageCommand,
    ) -> Result<StorageCommand, DecodeError> {
        let raw = self
            .connection
            .read_payload(command.byte_count)
            .await
            .map_err(|_| DecodeError::ConnectionClosed)?;
        let payload = parse_payload(raw, command.byte_count).map_err(DecodeError::ParseError)?;
        Ok(StorageCommand { payload, ..command })
    }
}
/// Checks that a raw data block ends with `\r\n` and keeps only its first
/// `byte_count` bytes (the data itself, without the terminator).
///
/// # Errors
/// `ParseError::InvalidFormat` when the block does not end with `\r\n`.
fn parse_payload(mut payload: String, byte_count: usize) -> Result<String, ParseError> {
    if !payload.ends_with("\r\n") {
        return Err(ParseError::InvalidFormat(
            "Expected \r\n at the end of string".to_string(),
        ));
    }
    payload.truncate(byte_count);
    Ok(payload)
}
fn parse_header(header: String) -> Result<Command, ParseError> {
// ? I think we cannot detect missing \r\n in header because we must read untill we see \r\n
// ? It is up to the client to not mess this up.
let keywords: Vec<_> = header.split_whitespace().collect();
let command = keywords
.get(0)
.ok_or(ParseError::InvalidFormat("Missing command".to_string()))?;
let key = keywords
.get(1)
.ok_or(ParseError::InvalidFormat("Missing key".to_string()))?
.to_string();
match command.to_lowercase().as_str() {
"get" => Ok(Command::Get(GetCommand { key })),
"set" | "add" | "replace" | "append" | "prepend" => {
let storage_command = parse_storage_command(&keywords, key)?;
match *command {
"set" => Ok(Command::Set(storage_command)),
"add" => Ok(Command::Add(storage_command)),
"replace" => Ok(Command::Replace(storage_command)),
"append" => Ok(Command::Append(storage_command)),
"prepend" => Ok(Command::Prepend(storage_command)),
_ => unreachable!(),
}
}
_ => Err(ParseError::UnknownCommand(command.to_string())),
}
}
/// Parses the arguments of a storage header:
/// `<cmd> <key> <flags> <exptime> <bytes> [noreply]`.
///
/// `keywords` is the whitespace-split header (name at index 0, key at
/// index 1); `key` is the already-extracted key. The returned command has an
/// empty `payload`, which is filled in once the data block has been read.
///
/// # Errors
/// `ParseError::InvalidFormat` when flags, exptime or byte count is missing or
/// not a valid number.
fn parse_storage_command(keywords: &[&str], key: String) -> Result<StorageCommand, ParseError> {
    let flags = parse_field::<u16>(keywords, 2, "flags")?;
    let exptime = parse_field::<i128>(keywords, 3, "exptime")?;
    let byte_count = parse_field::<usize>(keywords, 4, "byte count")?;
    // Only the literal `noreply` token requests a silent reply; any other
    // trailing token must not switch replies off.
    let no_reply = keywords.get(5).copied() == Some("noreply");
    Ok(StorageCommand {
        key,
        flags,
        exptime,
        byte_count,
        no_reply,
        payload: String::new(),
    })
}
/// Parses `keywords[index]` as a `T`.
///
/// # Errors
/// `ParseError::InvalidFormat` with `"<field_name> is missing"` when the
/// token is absent, or `"Expected a valid <field_name> value"` when it does
/// not parse.
fn parse_field<T: std::str::FromStr>(
    keywords: &[&str],
    index: usize,
    field_name: &str,
) -> Result<T, ParseError> {
    // `ok_or_else`/`map_err` closures build the error message only on failure,
    // instead of formatting it for every successfully parsed field.
    let raw = keywords
        .get(index)
        .ok_or_else(|| ParseError::InvalidFormat(format!("{} is missing", field_name)))?;
    raw.parse::<T>()
        .map_err(|_| ParseError::InvalidFormat(format!("Expected a valid {} value", field_name)))
}
/// Splits `in_string` at the first occurrence of `pat` and returns owned
/// copies of the text before and after it (the pattern itself is dropped).
///
/// # Panics
/// Panics if `pat` does not occur in `in_string`; callers check for it first.
fn split_once(in_string: &str, pat: &str) -> (String, String) {
    let (head, tail) = in_string.split_once(pat).unwrap();
    (head.to_owned(), tail.to_owned())
}
-- test_app.py --
import socket
import subprocess
import threading
import time
import random
# Address of the server under test; must match ADDRESS / PORT in main.rs.
HOST, PORT = "127.0.0.1", 8012
def start_process(timeout: float = 120.0) -> subprocess.Popen:
    """Start the server with ``cargo run`` and wait until it accepts connections.

    The server port is polled instead of sleeping for a fixed time, because
    the first ``cargo run`` may spend longer compiling than any fixed delay.

    The server's stdout is discarded rather than piped: nothing would ever
    read the pipe, and once its OS buffer filled up the server's ``println!``
    calls could block. stderr is inherited so build errors and panics stay
    visible.

    Raises:
        RuntimeError: if the process exits before the port opens.
        TimeoutError: if the port does not open within ``timeout`` seconds.
    """
    process = subprocess.Popen(["cargo", "run"], stdout=subprocess.DEVNULL)
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if process.poll() is not None:
            raise RuntimeError(f"server exited early with code {process.returncode}")
        try:
            with socket.create_connection((HOST, PORT), timeout=0.5):
                return process  # Tests will run after this
        except OSError:
            time.sleep(0.1)
    process.terminate()
    process.wait()
    raise TimeoutError(f"server did not start listening on {HOST}:{PORT} within {timeout}s")
def tcp_connection() -> socket.socket:
    """Open and return a new TCP connection to the server under test.

    If connecting fails, the socket is closed before the error propagates, so
    a refused connection does not leak a file descriptor.
    """
    conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        conn.connect((HOST, PORT))
    except BaseException:
        conn.close()
        raise
    return conn
def _reply_is_complete(reply: bytes) -> bool:
    """Heuristically decide whether ``reply`` holds one whole server response.

    Every response line ends in ``\\r\\n``. A ``get`` hit (``VALUE ...``) is only
    finished once its closing ``END\\r\\n`` line has arrived.
    """
    if not reply.endswith(b"\r\n"):
        return False
    if reply.startswith(b"VALUE"):
        return reply.endswith(b"END\r\n")
    return True


def send_command(socket: socket.socket, command: str, data: str = "") -> str:
    """Send a command line (and optional data block) and return the full reply.

    TCP is a byte stream, so one ``recv`` may return only part of a reply.
    Keep reading until the reply looks complete or the server closes the
    connection (then whatever arrived is returned).
    """
    socket.sendall(command.encode("utf-8"))
    if data:
        socket.sendall(data.encode("utf-8"))
    reply = b""
    while not _reply_is_complete(reply):
        chunk = socket.recv(1024)
        if not chunk:  # connection closed by the server
            break
        reply += chunk
    return reply.decode("utf-8")
def assert_response(response: str, expected: str, message: str):
    """Fail with a descriptive AssertionError when the server's reply is not ``expected``."""
    mismatch = f"{message}: Expected '{expected}', but got '{response}'"
    assert response == expected, mismatch
def test_commands(process):
    """Exercise every supported command over a single connection.

    Each step is ``(command line, data block, expected reply, failure message)``.
    The steps run in order because later steps depend on keys written by
    earlier ones (e.g. ``replace test`` relies on the earlier ``set test``).
    """
    steps = [
        # set / get
        ("set test 0 0 4\r\n", "1234\r\n", "STORED\r\n", "Set command failed"),
        ("get test\r\n", "", "VALUE test 0 4\r\n1234\r\nEND\r\n",
         "Get command failed for existing key"),
        ("get non-existing-key\r\n", "", "END\r\n",
         "Get command failed for non existing key"),
        # unknown command
        ("unknowncmd test\r\n", "", "ERROR\r\n", "Response failed for unknown command"),
        # add
        ("add newkey 0 0 4\r\n", "data\r\n", "STORED\r\n",
         "Add command failed when adding non existing key"),
        ("add newkey 0 0 4\r\n", "data\r\n", "NOT_STORED\r\n",
         "Add command failed when adding existing key"),
        # replace
        ("replace test 0 0 4\r\n", "john\r\n", "STORED\r\n",
         "Replace command failed when replacing existing key"),
        ("get test\r\n", "", "VALUE test 0 4\r\njohn\r\nEND\r\n",
         "Getting wrong key after replace"),
        ("replace test2 0 0 4\r\n", "data\r\n", "NOT_STORED\r\n",
         "Replace command failed when replacing non-existing key"),
        # append on an existing key
        ("append test 0 0 4\r\n", "more\r\n", "STORED\r\n",
         "Append command failed for existing key"),
        ("get test\r\n", "", "VALUE test 0 8\r\njohnmore\r\nEND\r\n",
         "Get command failed after append"),
        # prepend then append on an existing key
        ("set middle 0 0 4\r\n", "data\r\n", "STORED\r\n", "set command failed"),
        ("prepend middle 0 0 3\r\n", "pre\r\n", "STORED\r\n",
         "Prepend command failed for existing key"),
        ("get middle\r\n", "", "VALUE middle 0 7\r\npredata\r\nEND\r\n",
         "Get command failed after prepend"),
        ("append middle 0 0 3\r\n", "end\r\n", "STORED\r\n",
         "append command failed for existing key"),
        ("get middle\r\n", "", "VALUE middle 0 10\r\npredataend\r\nEND\r\n",
         "Get command failed after append"),
        # append / prepend on a missing key
        ("append foo 0 0 4\r\n", "test\r\n", "NOT_STORED\r\n",
         "Append command failed for non-existing key"),
        ("prepend foo 0 0 4\r\n", "test\r\n", "NOT_STORED\r\n",
         "Prepend command failed for non-existing key"),
    ]
    # Using ``with`` (instead of try/finally around the connect) means a failed
    # connect raises its real error, not an UnboundLocalError from cleanup, and
    # the socket is always closed afterwards.
    with tcp_connection() as conn:
        for command, data, expected, message in steps:
            assert_response(send_command(conn, command, data), expected, message)
    print("---- Test Commands Passed ----")
def test_concurrect(process: subprocess.Popen[bytes]):
    """Send ``set`` commands from 10 threads at once, 10 per thread.

    An exception raised in a worker thread does not propagate to the thread
    that joins it. Failures are therefore collected and re-raised here;
    otherwise a failing worker would only print a traceback and the test
    would still report success.
    """
    threads = []
    errors = []
    errors_lock = threading.Lock()

    def perform_set():
        try:
            with tcp_connection() as conn:
                for _ in range(10):
                    random_val = random.randint(0, 9)
                    assert_response(
                        send_command(conn, "set value 0 0 1\r\n", f"{random_val}\r\n"),
                        "STORED\r\n",
                        "Failed store in concurrent test case",
                    )
        except Exception as exc:  # re-raised in the main thread below
            with errors_lock:
                errors.append(exc)

    for _ in range(10):
        t = threading.Thread(target=perform_set)
        threads.append(t)
        t.start()
    # Wait for all threads to complete
    for t in threads:
        t.join()
    if errors:
        raise AssertionError(
            f"{len(errors)} of {len(threads)} worker threads failed; first error: {errors[0]!r}"
        ) from errors[0]
    # After the concurrent writes, the key must hold one of the digits written.
    with tcp_connection() as conn:
        response = send_command(conn, "get value\r\n")
    expected = {f"VALUE value 0 1\r\n{digit}\r\nEND\r\n" for digit in range(10)}
    assert response in expected, f"Unexpected value after concurrent sets: '{response}'"
    print("---- Test Concurrent Passed ----")
def test_expiration(process):
    """Check that a key stored with a 4-second exptime is readable before and gone after it."""
    # ``with`` closes the socket even on failure and, unlike the former
    # try/finally, cannot raise UnboundLocalError when the connect itself fails.
    with tcp_connection() as conn:
        # Set a key with a 4-second expiration time
        assert_response(
            send_command(conn, "set tempkey 0 4 5\r\n", "hello\r\n"),
            "STORED\r\n",
            "Set command failed",
        )
        # Retrieve the key before expiration
        assert_response(
            send_command(conn, "get tempkey\r\n"),
            "VALUE tempkey 0 5\r\nhello\r\nEND\r\n",
            "Cannot find temp key before it expires",
        )
        # Wait for expiration
        time.sleep(4.1)
        # Attempt to retrieve the key after expiration
        assert_response(
            send_command(conn, "get tempkey\r\n"),
            "END\r\n",
            "Temp key is not deleted after expiration time",
        )
    print("---- Test Expiration Passed ----")
if __name__ == "__main__":
    # Start the server outside the try: if start-up fails there is nothing to
    # clean up, and a NameError on ``process`` must not mask the real error.
    process = start_process()
    try:
        print("Process Started, starting tests ...")
        test_commands(process)
        test_concurrect(process)
        test_expiration(process)
        print("--- All Tests Passed ---")
    finally:
        process.terminate()
        process.wait()