From 79572e626997cea8096e0e809d11153c7f01e056 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 14 Jan 2025 16:01:10 -0800 Subject: [PATCH] lighthouse/quorum: avoid split brain and add shrink_only support (#71) --- proto/torchft.proto | 2 + src/lib.rs | 4 +- src/lighthouse.rs | 255 ++++++++++++++++++++++++++++++++++--- src/manager.rs | 19 ++- torchft/lighthouse_test.py | 34 ----- torchft/manager.py | 7 +- torchft/torchft.pyi | 1 + 7 files changed, 261 insertions(+), 61 deletions(-) diff --git a/proto/torchft.proto b/proto/torchft.proto index e84855c..67a42c0 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -41,6 +41,7 @@ message QuorumMember { string store_address = 3; int64 step = 4; uint64 world_size = 5; + bool shrink_only = 6; } message Quorum { @@ -72,6 +73,7 @@ message ManagerQuorumRequest { int64 rank = 1; int64 step = 2; string checkpoint_server_addr = 3; + bool shrink_only = 4; } message ManagerQuorumResponse { diff --git a/src/lib.rs b/src/lib.rs index 199d4cf..8923682 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -105,13 +105,14 @@ impl ManagerClient { }) } - #[pyo3(signature = (rank, step, checkpoint_server_addr, timeout=None))] + #[pyo3(signature = (rank, step, checkpoint_server_addr, shrink_only, timeout=None))] fn quorum( &mut self, py: Python<'_>, rank: i64, step: i64, checkpoint_server_addr: String, + shrink_only: bool, timeout: Option, ) -> Result<(i64, i64, i64, String, String, i64, Option, i64, bool), StatusError> { py.allow_threads(move || { @@ -119,6 +120,7 @@ impl ManagerClient { rank: rank, step: step, checkpoint_server_addr: checkpoint_server_addr, + shrink_only: shrink_only, }); // This notifies the server about the timeout but doesn't affect the // endpoint timeout which we set on client creation. diff --git a/src/lighthouse.rs b/src/lighthouse.rs index c70bb74..420bb3c 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -6,6 +6,7 @@ use core::net::SocketAddr; use std::collections::HashMap; +use std::collections::HashSet; use std::sync::Arc; use std::time::Duration; use std::time::{Instant, SystemTime}; @@ -115,21 +116,23 @@ fn quorum_compute( opt: &LighthouseOpt, ) -> (Option>, String) { let heartbeats = &state.heartbeats; - let healthy_participants: HashMap = state - .participants - .clone() - .into_iter() - .filter(|(replica_id, _details)| { - let last_heartbeat = heartbeats.get(replica_id); - if last_heartbeat.is_none() { - return false; + let healthy_replicas: HashSet<&String> = heartbeats + .iter() + .filter_map(|(replica_id, last_heartbeat)| { + if now.duration_since(*last_heartbeat) < Duration::from_millis(opt.heartbeat_timeout_ms) + { + return Some(replica_id); } - - now.duration_since(*last_heartbeat.unwrap()) - < Duration::from_millis(opt.heartbeat_timeout_ms) + None }) .collect(); + let healthy_participants: HashMap<&String, &QuorumMemberDetails> = state + .participants + .iter() + .filter(|(replica_id, _details)| healthy_replicas.contains(replica_id)) + .collect(); + let mut candidate_participants: Vec = healthy_participants .values() .map(|details| details.member.clone()) @@ -138,16 +141,35 @@ fn quorum_compute( // Sort by replica ID to get a consistent ordering across runs. 
candidate_participants.sort_by_key(|p| p.replica_id.clone()); + let shrink_only = healthy_participants + .iter() + .any(|(_, details)| details.member.shrink_only); + let metadata = format!( - "[{}/{} participants healthy]", + "[{}/{} participants healthy][shrink_only={}]", healthy_participants.len(), - state.participants.len() + state.participants.len(), + shrink_only, ); // Check if we can use the previous quorum. + // TODO: do we still need this given we have heartbeats? if state.prev_quorum.is_some() { let prev_quorum = state.prev_quorum.as_ref().unwrap(); + let prev_replica_ids: HashSet<&String> = prev_quorum + .participants + .iter() + .map(|p| &p.replica_id) + .collect(); + + if shrink_only { + candidate_participants = candidate_participants + .into_iter() + .filter(|p| prev_replica_ids.contains(&p.replica_id)) + .collect(); + } + // Fast quorum is when all previous participants are still in the quorum // and we have enough participants to form a quorum. let is_fast_quorum = prev_quorum @@ -163,11 +185,12 @@ fn quorum_compute( } } + // Minimum quorum size check. if healthy_participants.len() < opt.min_replicas as usize { return ( None, format!( - "No quorum, only have {} participants, need {} {}", + "No quorum, only have {} participants, need min_replicas {} {}", healthy_participants.len(), opt.min_replicas, metadata @@ -175,18 +198,36 @@ fn quorum_compute( ); } + // Avoid split brain by requiring at least half of the known alive workers. + if healthy_participants.len() <= healthy_replicas.len() / 2 { + return ( + None, + format!( + "No quorum, only have {} participants, need at least half of {} healthy workers {}", + healthy_participants.len(), + healthy_replicas.len(), + metadata + ), + ); + } + + let all_healthy_joined = healthy_participants.len() == healthy_replicas.len(); + // Quorum is valid at this point but lets wait for stragglers. 
let first_joined = healthy_participants .values() .map(|details| details.joined) .min() .unwrap_or(now); - if now.duration_since(first_joined) < Duration::from_millis(opt.join_timeout_ms) { + if !all_healthy_joined + && now.duration_since(first_joined) < Duration::from_millis(opt.join_timeout_ms) + { return ( None, format!( - "Valid quorum with {} participants, waiting for stragglers due to join timeout {}", + "Valid quorum with {} participants, waiting for {} healthy but not participating stragglers due to join timeout {}", healthy_participants.len(), + healthy_replicas.len() - healthy_participants.len(), metadata ), ); @@ -546,17 +587,43 @@ mod tests { store_address: "".to_string(), step: 1, world_size: 1, + shrink_only: false, }, }, ); state.heartbeats.insert("a".to_string(), now); - assert!(!quorum_compute(now, &state, &opt).0.is_some()); + state.participants.insert( + "b".to_string(), + QuorumMemberDetails { + joined: now, + member: QuorumMember { + replica_id: "b".to_string(), + address: "".to_string(), + store_address: "".to_string(), + step: 1, + world_size: 1, + shrink_only: false, + }, + }, + ); + state.heartbeats.insert("b".to_string(), now); + + // all healthy workers participating + let (quorum_met, reason) = quorum_compute(now, &state, &opt); + assert!(quorum_met.is_some(), "{}", reason); + + // add healthy worker but not participating + state.heartbeats.insert("c".to_string(), now); + let (quorum_met, reason) = quorum_compute(now, &state, &opt); + assert!(quorum_met.is_none(), "{}", reason); + assert!(reason.contains("join timeout"), "{}", reason); + // increase elapsed time to pass join timeout state.participants.get_mut("a").unwrap().joined = now.sub(Duration::from_secs(10 * 60 * 60)); - - assert!(quorum_compute(now, &state, &opt).0.is_some()); + let (quorum_met, reason) = quorum_compute(now, &state, &opt); + assert!(quorum_met.is_some(), "{}", reason); Ok(()) } @@ -591,6 +658,7 @@ mod tests { store_address: "".to_string(), step: 1, world_size: 1, + shrink_only: false, }, }, ); @@ -617,6 +685,7 @@ mod tests { store_address: "".to_string(), step: 1, world_size: 1, + shrink_only: false, }, }, ); @@ -662,12 +731,17 @@ mod tests { store_address: "".to_string(), step: 1, world_size: 1, + shrink_only: false, }, }, ); state.heartbeats.insert("a".to_string(), now); - assert!(!quorum_compute(now, &state, &opt).0.is_some()); + // Not proceeding since one worker is alive but not participating + state.heartbeats.insert("b".to_string(), now); + let (quorum_met, reason) = quorum_compute(now, &state, &opt); + assert!(quorum_met.is_none(), "{}", reason); + assert!(reason.contains("need at least half"), "{}", reason); state.prev_quorum = Some(Quorum { quorum_id: 1, @@ -677,6 +751,7 @@ mod tests { store_address: "".to_string(), step: 1, world_size: 1, + shrink_only: false, }], created: Some(SystemTime::now().into()), }); @@ -694,6 +769,7 @@ mod tests { store_address: "".to_string(), step: 1, world_size: 1, + shrink_only: false, }, }, ); @@ -707,6 +783,92 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_quorum_shrink_only() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60 * 60 * 1000, // 1hr + quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, + }; + + let mut state = State { + channel: broadcast::channel(16).0, + participants: HashMap::new(), + prev_quorum: None, + quorum_id: 0, + heartbeats: HashMap::new(), + }; + + let now = Instant::now(); + + state.prev_quorum = Some(Quorum { + quorum_id: 1, + participants: vec![ + 
QuorumMember { + replica_id: "a".to_string(), + address: "".to_string(), + store_address: "".to_string(), + step: 1, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "b".to_string(), + address: "".to_string(), + store_address: "".to_string(), + step: 1, + world_size: 1, + shrink_only: false, + }, + ], + created: Some(SystemTime::now().into()), + }); + + state.participants.insert( + "a".to_string(), + QuorumMemberDetails { + joined: now, + member: QuorumMember { + replica_id: "a".to_string(), + address: "".to_string(), + store_address: "".to_string(), + step: 1, + world_size: 1, + shrink_only: true, + }, + }, + ); + state.heartbeats.insert("a".to_string(), now); + + // insert particpant that was not in prev quorum + state.participants.insert( + "c".to_string(), + QuorumMemberDetails { + joined: now, + member: QuorumMember { + replica_id: "c".to_string(), + address: "".to_string(), + store_address: "".to_string(), + step: 1, + world_size: 1, + shrink_only: true, + }, + }, + ); + state.heartbeats.insert("c".to_string(), now); + + let (quorum_met, reason) = quorum_compute(now, &state, &opt); + assert!(quorum_met.is_some(), "{}", reason); + + let quorum = quorum_met.unwrap(); + assert!(quorum.len() == 1); + assert!(quorum[0].replica_id == "a"); + + Ok(()) + } + #[tokio::test] async fn test_lighthouse_e2e() -> Result<()> { let opt = LighthouseOpt { @@ -738,6 +900,7 @@ mod tests { store_address: "".to_string(), step: 10, world_size: 1, + shrink_only: false, }), }); @@ -750,6 +913,55 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_quorum_split_brain() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60 * 60 * 1000, // 1hr + quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, + }; + + let mut state = State { + channel: broadcast::channel(16).0, + participants: HashMap::new(), + prev_quorum: None, + quorum_id: 0, + heartbeats: HashMap::new(), + }; + + let now = Instant::now(); + + assert!(!quorum_compute(now, &state, &opt).0.is_some()); + + state.participants.insert( + "a".to_string(), + QuorumMemberDetails { + joined: now, + member: QuorumMember { + replica_id: "a".to_string(), + address: "".to_string(), + store_address: "".to_string(), + step: 1, + world_size: 1, + shrink_only: false, + }, + }, + ); + state.heartbeats.insert("a".to_string(), now); + let (quorum_met, reason) = quorum_compute(now, &state, &opt); + assert!(quorum_met.is_some(), "{}", reason); + + // Not proceeding since one worker is alive but not participating + state.heartbeats.insert("b".to_string(), now); + let (quorum_met, reason) = quorum_compute(now, &state, &opt); + assert!(quorum_met.is_none(), "{}", reason); + assert!(reason.contains("at least half"), "{}", reason); + + Ok(()) + } + #[tokio::test] async fn test_quorum_changed() { let a = vec![QuorumMember { @@ -758,6 +970,7 @@ mod tests { store_address: "".to_string(), step: 1, world_size: 1, + shrink_only: false, }]; let b = vec![QuorumMember { replica_id: "1".to_string(), @@ -765,6 +978,7 @@ mod tests { store_address: "changed".to_string(), step: 1000, world_size: 1, + shrink_only: false, }]; // replica_id is the same @@ -776,6 +990,7 @@ mod tests { store_address: "".to_string(), step: 1, world_size: 1, + shrink_only: false, }]; // replica_id changed assert!(quorum_changed(&a, &c)); diff --git a/src/manager.rs b/src/manager.rs index a51bbe9..76efb58 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -213,6 +213,7 @@ impl ManagerService for Arc { store_address: 
self.store_address.clone(), step: req.step, world_size: self.world_size, + shrink_only: req.shrink_only, }), }); @@ -250,12 +251,18 @@ impl ManagerService for Arc { let mut participants = quorum.participants.clone(); participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id)); - let mut replica_rank = 10000000000; - for (i, p) in participants.iter().enumerate() { + let replica_rank = participants.iter().enumerate().find_map(|(i, p)| { if p.replica_id == self.replica_id { - replica_rank = i; - break; + Some(i) + } else { + None } + }); + if replica_rank.is_none() { + return Err(Status::not_found(format!( + "replica {} not participating in returned quorum", + self.replica_id + ))); } let max_step = participants.iter().map(|p| p.step).max().unwrap(); @@ -291,7 +298,7 @@ impl ManagerService for Arc { max_step: max_step, max_rank: max_rank, max_world_size: max_participants.len() as i64, - replica_rank: replica_rank as i64, + replica_rank: replica_rank.unwrap() as i64, replica_world_size: participants.len() as i64, heal: heal, }; @@ -469,6 +476,7 @@ mod tests { rank: 0, step: 123, checkpoint_server_addr: "addr".to_string(), + shrink_only: false, }); request.set_timeout(Duration::from_secs(10)); let resp = client.quorum(request).await?.into_inner(); @@ -526,6 +534,7 @@ mod tests { rank: 0, step: 0, checkpoint_server_addr: "addr".to_string(), + shrink_only: false, }); request.set_timeout(Duration::from_secs(10)); diff --git a/torchft/lighthouse_test.py b/torchft/lighthouse_test.py index 36ab62c..38700b6 100644 --- a/torchft/lighthouse_test.py +++ b/torchft/lighthouse_test.py @@ -58,40 +58,6 @@ def test_join_timeout_behavior(self) -> None: join_timeout_ms=400, ) - # Create a manager that tries to join - try: - store = dist.TCPStore( - host_name="localhost", - port=0, - is_master=True, - wait_for_workers=False, - ) - pg = ProcessGroupGloo() - manager = Manager( - pg=pg, - min_replica_size=1, - load_state_dict=lambda x: None, - state_dict=lambda: None, - replica_id=f"lighthouse_test", - store_addr="localhost", - store_port=store.port, - rank=0, - world_size=1, - use_async_quorum=False, - lighthouse_addr=lighthouse.address(), - ) - - start_time = time.time() - manager.start_quorum() - time_taken = time.time() - start_time - assert time_taken > 0.4, f"Time taken to join: {time_taken} < 0.4s" - - finally: - # Cleanup - lighthouse.shutdown() - if "manager" in locals(): - manager.shutdown() - def test_heartbeat_timeout_ms_sanity(self) -> None: lighthouse = Lighthouse( bind="[::]:0", diff --git a/torchft/manager.py b/torchft/manager.py index c75bb48..9bc4ba1 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -338,6 +338,7 @@ def callback( def start_quorum( self, allow_heal: bool = True, + shrink_only: bool = False, timeout: Optional[timedelta] = None, ) -> None: """ @@ -372,6 +373,7 @@ def start_quorum( self._quorum_future = self._executor.submit( self._async_quorum, allow_heal=allow_heal, + shrink_only=shrink_only, timeout=timeout or self._timeout, ) if not self._use_async_quorum: @@ -396,7 +398,9 @@ def wait_quorum(self) -> None: ), "must call start_quorum before wait_quorum" self._quorum_future.result() - def _async_quorum(self, allow_heal: bool, timeout: timedelta) -> None: + def _async_quorum( + self, allow_heal: bool, shrink_only: bool, timeout: timedelta + ) -> None: ( quorum_id, replica_rank, @@ -411,6 +415,7 @@ def _async_quorum(self, allow_heal: bool, timeout: timedelta) -> None: rank=self._rank, step=self._step, checkpoint_server_addr=self._ckpt_server.address(), + 
shrink_only=shrink_only, timeout=timeout, ) diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi index a5196ea..a694920 100644 --- a/torchft/torchft.pyi +++ b/torchft/torchft.pyi @@ -8,6 +8,7 @@ class ManagerClient: rank: int, step: int, checkpoint_server_addr: str, + shrink_only: bool, timeout: Optional[timedelta] = None, ) -> Tuple[int, int, int, str, str, int, Optional[int], int, bool]: ... def checkpoint_address(
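
The new quorum rules added to quorum_compute in src/lighthouse.rs above reduce to three checks: a shrink_only filter against the previous quorum, a majority ("split brain") guard over the replicas with recent heartbeats, and a straggler wait bounded by the join timeout. The Python sketch below only mirrors that logic under simplified assumptions; the function and parameter names (can_form_quorum, waited_ms, and so on) are hypothetical and the Rust code above is authoritative.

def can_form_quorum(
    participants,       # healthy replica ids that have requested a quorum
    healthy_replicas,   # all replica ids with a recent heartbeat
    prev_quorum,        # replica ids of the previous quorum, or None
    min_replicas,       # min_replicas from LighthouseOpt
    shrink_only,        # True if any healthy participant set shrink_only
    waited_ms,          # elapsed time since the first participant joined
    join_timeout_ms,
):
    candidates = set(participants)

    if prev_quorum is not None:
        # shrink_only: never admit replicas that were not in the previous quorum.
        if shrink_only:
            candidates &= set(prev_quorum)
        # Fast path: every member of the previous quorum is healthy and
        # participating, so a new quorum can be formed immediately.
        if set(prev_quorum) <= set(participants):
            return sorted(candidates)

    # Minimum size check.
    if len(participants) < min_replicas:
        return None

    # Split-brain guard: require a strict majority of the known-alive replicas,
    # so two disjoint partitions can never both form a quorum.
    if len(participants) <= len(healthy_replicas) // 2:
        return None

    # Straggler wait: if some healthy replicas have not joined yet, hold the
    # quorum open until the join timeout expires.
    if len(participants) < len(healthy_replicas) and waited_ms < join_timeout_ms:
        return None

    return sorted(candidates)

# Mirrors test_quorum_shrink_only above: "c" was not in the previous quorum,
# so with shrink_only set the resulting quorum is just ["a"].
assert can_form_quorum(
    participants={"a", "c"},
    healthy_replicas={"a", "c"},
    prev_quorum={"a", "b"},
    min_replicas=1,
    shrink_only=True,
    waited_ms=0,
    join_timeout_ms=60 * 60 * 1000,
) == ["a"]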
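
From the caller's side, shrink_only now threads from Manager.start_quorum through ManagerClient.quorum to the lighthouse. The sketch below is illustrative rather than part of the patch: the Manager construction mirrors the test removed from torchft/lighthouse_test.py above, while the import paths and the Lighthouse constructor arguments are assumptions based on the repository layout at the time of this change and may differ in other versions.

import torch.distributed as dist

from torchft.manager import Manager
from torchft.process_group import ProcessGroupGloo
from torchft.torchft import Lighthouse

# A single-replica lighthouse, as in the tests above.
lighthouse = Lighthouse(bind="[::]:0", min_replicas=1)

store = dist.TCPStore(
    host_name="localhost",
    port=0,
    is_master=True,
    wait_for_workers=False,
)
manager = Manager(
    pg=ProcessGroupGloo(),
    min_replica_size=1,
    load_state_dict=lambda x: None,
    state_dict=lambda: None,
    replica_id="shrink_only_example",
    store_addr="localhost",
    store_port=store.port,
    rank=0,
    world_size=1,
    use_async_quorum=False,
    lighthouse_addr=lighthouse.address(),
)

try:
    # If any participant passes shrink_only=True, the lighthouse filters the
    # candidate set down to members of the previous quorum, so replicas can
    # leave but not join for this round.
    manager.start_quorum(shrink_only=True)
finally:
    manager.shutdown()
    lighthouse.shutdown()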