diff --git a/pymysqlreplication/gtid.py b/pymysqlreplication/gtid.py index 16d02a7c..2d149a4a 100644 --- a/pymysqlreplication/gtid.py +++ b/pymysqlreplication/gtid.py @@ -287,6 +287,8 @@ def merge_gtid(self, gtid): self.gtids = new_gtids def __contains__(self, other): + if isinstance(other, GtidSet): + return all(other_gtid in self.gtids for other_gtid in other.gtids) if isinstance(other, Gtid): return any(other in x for x in self.gtids) raise NotImplementedError @@ -296,6 +298,13 @@ def __add__(self, other): new = GtidSet(self.gtids) new.merge_gtid(other) return new + + if isinstance(other, GtidSet): + new = GtidSet(self.gtids) + for gtid in other.gtids: + new.merge_gtid(gtid) + return new + raise NotImplementedError def __str__(self):