Skip to content

Commit

Permalink
feature: support redis-server network layer
Browse files Browse the repository at this point in the history
  • Loading branch information
runningwater committed Aug 1, 2024
1 parent c09d28b commit fd8be0b
Show file tree
Hide file tree
Showing 10 changed files with 700 additions and 16 deletions.
436 changes: 436 additions & 0 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,9 @@ enum_dispatch = "0.3.13"
lazy_static = "1.5.0"
# This library provides a convenient derive macro for the standard library’s std::error::Error trait.
thiserror = "1.0.63"
tokio = { version = "1.39.2", features = ["rt", "net", "macros", "fs", "rt-multi-thread"] }
tokio-util = { version = "0.7.11", features = ["codec"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tokio-stream = "0.1.15"
futures = { version = "0.3.30", default-features = false }
2 changes: 1 addition & 1 deletion src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl Backend {
.get(key)
.and_then(|v| v.get(field).map(|v| v.value().clone()))
}
pub fn hset(&mut self, key: String, field: String, value: RespFrame) {
pub fn hset(&self, key: String, field: String, value: RespFrame) {
let hmap = self.hmap.entry(key).or_default();
hmap.insert(field, value);
}
Expand Down
88 changes: 86 additions & 2 deletions src/cmd/hmap.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,49 @@
use crate::cmd::{extract_args, validate_command, HGetAll, HSet};
use crate::cmd::{extract_args, validate_command, CommandExecutor, HGetAll, HSet, RESP_OK};
use crate::{
cmd::{CommandError, HGet},
RespArray, RespFrame,
Backend, BulkString, RespArray, RespFrame, RespNull,
};

//=================== 实现 CommandExecutor trait for Command
impl CommandExecutor for HGet {
fn execute(self, backend: &Backend) -> RespFrame {
backend
.hget(&self.key, &self.field)
.unwrap_or(RespFrame::Null(RespNull))
}
}
impl CommandExecutor for HGetAll {
fn execute(self, backend: &Backend) -> RespFrame {
let hmap = backend.hgetall(&self.key);
match hmap {
Some(hmap) => {
let mut data = Vec::with_capacity(hmap.len());
for v in hmap.iter() {
let key = v.key().to_owned();
data.push((key, v.value().clone()));
}
if self.sort {
data.sort_by(|a, b| a.0.cmp(&b.0));
}
let ret = data
.into_iter()
.flat_map(|(k, v)| vec![BulkString::new(k.as_bytes()).into(), v])
.collect::<Vec<RespFrame>>();

RespArray::new(ret).into()
}
None => RespFrame::Null(RespNull),
}
}
}
impl CommandExecutor for HSet {
fn execute(self, backend: &Backend) -> RespFrame {
backend.hset(self.key, self.field, self.value);
RESP_OK.clone()
}
}

//=================== 实现 TryFrom trait for Command
impl TryFrom<RespArray> for HGet {
type Error = CommandError;

Expand All @@ -30,6 +70,7 @@ impl TryFrom<RespArray> for HGetAll {
match args.next() {
Some(RespFrame::BulkString(key)) => Ok(HGetAll {
key: String::from_utf8(key.0)?,
sort: false,
}),
_ => Err(CommandError::InvalidCommand(
"Invalid key for HGETALL command".to_string(),
Expand Down Expand Up @@ -108,6 +149,49 @@ mod tests {
RespFrame::BulkString(BulkString::new(b"myvalue".to_vec()))
);

Ok(())
}
#[test]
fn test_hset_hgetall_commands() -> Result<()> {
let backend = Backend::new();
let set_cmd = HSet {
key: "mykey".to_string(),
field: "myfield".to_string(),
value: RespFrame::BulkString(BulkString::new(b"myvalue".to_vec())),
};
let set_result = set_cmd.execute(&backend);
assert_eq!(set_result, RESP_OK.clone());

let get_cmd = HGet {
key: "mykey".to_string(),
field: "myfield".to_string(),
};
let get_result = get_cmd.execute(&backend);
assert_eq!(
get_result,
RespFrame::BulkString(BulkString::new(b"myvalue".to_vec()))
);

let set_cmd = HSet {
key: "mykey".to_string(),
field: "hello".to_string(),
value: RespFrame::BulkString(BulkString::new(b"world".to_vec())),
};
set_cmd.execute(&backend);

let getall_cmd = HGetAll {
key: "mykey".to_string(),
sort: true,
};
let getall_result = getall_cmd.execute(&backend);
let expected_result = RespArray::new(vec![
BulkString::new(b"hello".to_vec()).into(),
BulkString::new(b"world".to_vec()).into(),
BulkString::new(b"myfield".to_vec()).into(),
BulkString::new(b"myvalue".to_vec()).into(),
]);
assert_eq!(getall_result, expected_result.into());

Ok(())
}
}
6 changes: 3 additions & 3 deletions src/cmd/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ use crate::{

//=================== 实现 CommandExecutor trait for Command
impl CommandExecutor for Get {
fn execute(&self, backend: &Backend) -> RespFrame {
fn execute(self, backend: &Backend) -> RespFrame {
backend.get(&self.key).unwrap_or(RespFrame::Null(RespNull))
}
}
impl CommandExecutor for Set {
fn execute(&self, backend: &Backend) -> RespFrame {
fn execute(self, backend: &Backend) -> RespFrame {
backend.set(self.key.clone(), self.value.clone());
RESP_OK.clone()
}
Expand Down Expand Up @@ -79,7 +79,7 @@ mod tests {
buf.extend_from_slice(b"*3\r\n$3\r\nset\r\n$5\r\nhello\r\n$5\r\nworld\r\n");

let frame = RespArray::decode(&mut buf)?;
let result: Set = frame.try_into()?;
let result: Set = frame.try_into()?; // Set::try_from(frame)
assert_eq!(result.key, "hello");
assert_eq!(
result.value,
Expand Down
57 changes: 49 additions & 8 deletions src/cmd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use enum_dispatch::enum_dispatch;
use lazy_static::lazy_static;
use thiserror::Error;
use tracing::info;

use crate::{Backend, RespArray, RespError, RespFrame, SimpleString};

Expand Down Expand Up @@ -30,53 +32,92 @@ pub enum CommandError {
FromUtf8Error(#[from] std::string::FromUtf8Error),
}

#[enum_dispatch]
pub trait CommandExecutor {
fn execute(&self, backend: &Backend) -> RespFrame;
fn execute(self, backend: &Backend) -> RespFrame;
}

#[derive(Debug)]
#[enum_dispatch(CommandExecutor)]
pub enum Command {
Get(Get),
Set(Set),
HGet(HGet),
HSet(HSet),
HGetAll(HGetAll),

// unrecognized command
Unrecognized(Unrecognized),
}

#[allow(dead_code)]
#[derive(Debug)]
pub struct Get {
key: String,
}
#[allow(dead_code)]
#[derive(Debug)]
pub struct Set {
key: String,
value: RespFrame,
}
#[allow(dead_code)]
#[derive(Debug)]
pub struct HGet {
key: String,
field: String,
}
#[allow(dead_code)]
#[derive(Debug)]
pub struct HSet {
key: String,
field: String,
value: RespFrame,
}
#[allow(dead_code)]
#[derive(Debug)]
pub struct HGetAll {
key: String,
sort: bool,
}
#[derive(Debug)]
pub struct Unrecognized;

impl TryFrom<RespFrame> for Command {
type Error = CommandError;
fn try_from(v: RespFrame) -> Result<Self, Self::Error> {
match v {
RespFrame::Array(array) => array.try_into(),
_ => Err(CommandError::InvalidCommand(
"Command must be an array".to_string(),
)),
}
}
}

impl CommandExecutor for Unrecognized {
fn execute(self, _: &Backend) -> RespFrame {
// RespFrame::Error(SimpleError::new("Unrecognized command".to_string()))
info!("Unrecognized command");
RESP_OK.clone()
}
}

impl TryFrom<RespArray> for Command {
type Error = CommandError;
fn try_from(_frame: RespArray) -> Result<Self, Self::Error> {
todo!()
fn try_from(v: RespArray) -> Result<Self, Self::Error> {
let first = v.first();
match first {
Some(RespFrame::BulkString(ref cmd)) => {
let cmd_str = String::from_utf8_lossy(cmd.trim_ascii());
match cmd_str.to_ascii_lowercase().as_str() {
"get" => Ok(Get::try_from(v)?.into()),
"set" => Ok(Set::try_from(v)?.into()),
"hget" => Ok(HGet::try_from(v)?.into()),
"hset" => Ok(HSet::try_from(v)?.into()),
"hgetall" => Ok(HGetAll::try_from(v)?.into()),
_ => Ok(Unrecognized.into()),
}
}
_ => Err(CommandError::InvalidCommand(
"Command must have a BulkString as the first argument".to_string(),
)),
}
}
}

Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod backend;
pub mod cmd;
pub mod network;
mod resp;

pub use backend::*;
Expand Down
31 changes: 29 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,30 @@
fn main() {
println!("Hello, world!");
use anyhow::Result;
use tokio::net::TcpListener;
use tracing::{error, info};

use simple_redis::{network, Backend};

#[tokio::main]
async fn main() -> Result<()> {
// Initialize tracing library
tracing_subscriber::fmt::init();

let addr = "0.0.0.0:6379";
info!("Simple-Redis-Server is listening on {}", addr);

let listener = TcpListener::bind(addr).await?;
let backend = Backend::new();
loop {
let (stream, raddr) = listener.accept().await?;
info!("Accepted connection from: {}", raddr);
let cloned_backend = backend.clone(); // 克隆一个 backend 供子任务使用
tokio::spawn(async move {
match network::handle_connection(stream, cloned_backend).await {
Ok(_) => info!("Connection from {} exited", raddr),
Err(e) => error!("Error handling connection for {}: {:?}", raddr, e),
}
});
}
#[allow(unreachable_code)]
Ok(())
}
87 changes: 87 additions & 0 deletions src/network.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use anyhow::Result;
use bytes::BytesMut;
use futures::SinkExt;
use tokio::net::TcpStream;
use tokio_stream::StreamExt;
use tokio_util::codec::{Decoder, Encoder, Framed};
use tracing::info;

use crate::cmd::{Command, CommandExecutor};
use crate::{Backend, RespDecode, RespEncode, RespError, RespFrame};

#[derive(Debug)]
struct RespFrameCodec;

#[derive(Debug)]
struct RedisRequest {
frame: RespFrame,
backend: Backend,
}
#[derive(Debug)]
struct RedisResponse {
frame: RespFrame,
}

pub async fn handle_connection(stream: TcpStream, backend: Backend) -> Result<()> {
// how to get a frame from the stream
// call request_handler to handle the request
// send the response back to the stream
let mut framed = Framed::new(stream, RespFrameCodec);

loop {
let cloned_backend = backend.clone(); // Clone 一个 backend 供子任务使用
match framed.next().await {
Some(Ok(frame)) => {
info!("Received frame: {:?}", frame);
let request = RedisRequest {
frame,
backend: cloned_backend,
};
let response = request_handler(request).await?;
info!("Sending response: {:?}", response);
// 向 stream 发送响应
framed.send(response.frame).await?
}
Some(Err(err)) => return Err(err),
None => return Ok(()),
}
}
}

// 处理一个请求并返回响应
async fn request_handler(request: RedisRequest) -> Result<RedisResponse> {
let (frame, backend) = (request.frame, request.backend);
let cmd = Command::try_from(frame)?;
info!("Executing command: {:?}", cmd);
let frame = cmd.execute(&backend);
Ok(RedisResponse { frame })
}

impl Encoder<RespFrame> for RespFrameCodec {
type Error = anyhow::Error;

fn encode(
&mut self,
item: RespFrame,
dst: &mut BytesMut,
) -> std::result::Result<(), Self::Error> {
let encoded = item.encode();
dst.extend_from_slice(&encoded); // 转化成 bytes 并贝到 dst
Ok(())
}
}
impl Decoder for RespFrameCodec {
type Item = RespFrame;
type Error = anyhow::Error;

fn decode(
&mut self,
src: &mut BytesMut,
) -> std::result::Result<Option<Self::Item>, Self::Error> {
match RespFrame::decode(src) {
Ok(frame) => Ok(Some(frame)),
Err(RespError::NotComplete) => Ok(None),
Err(err) => Err(err.into()),
}
}
}
Loading

0 comments on commit fd8be0b

Please sign in to comment.