Skip to content

Commit

Permalink
added direct_sums
Browse files Browse the repository at this point in the history
  • Loading branch information
jakezweifler committed May 13, 2022
1 parent b2de1ab commit 6927616
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 11 deletions.
42 changes: 40 additions & 2 deletions Matrix.v
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ Definition kron {m n o p : nat} (A : Matrix m n) (B : Matrix o p) :
Matrix (m*o) (n*p) :=
fun x y => Cmult (A (x / o) (y / p)) (B (x mod o) (y mod p)).

Definition direct_sum {m n o p : nat} (A : Matrix m n) (B : Matrix o p) :
Matrix (m+o) (n+p) :=
fun x y => if (x <? m) || (y <? n) then A x y else B (x - m) (y - n).

Definition transpose {m n} (A : Matrix m n) : Matrix n m :=
fun x y => A y x.

Expand Down Expand Up @@ -227,6 +231,12 @@ Fixpoint Mmult_n n {m} (A : Square m) : Square m :=
| S n' => Mmult A (Mmult_n n' A)
end.

(** Direct sum of n copies of A *)
Fixpoint direct_sum_n n {m1 m2} (A : Matrix m1 m2) : Matrix (n*m1) (n*m2) :=
match n with
| 0 => @Zero 0 0
| S n' => direct_sum A (direct_sum_n n' A)
end.


(** * Showing that M is a vector space *)
Expand All @@ -249,14 +259,13 @@ Program Instance M_is_vector_space : forall n m, Vector_Space (Matrix n m) C :=
Solve All Obligations with program_simpl; prep_matrix_equality; lca.




(** Notations *)
Infix "∘" := dot (at level 40, left associativity) : matrix_scope.
Infix ".+" := Mplus (at level 50, left associativity) : matrix_scope.
Infix ".*" := scale (at level 40, left associativity) : matrix_scope.
Infix "×" := Mmult (at level 40, left associativity) : matrix_scope.
Infix "⊗" := kron (at level 40, left associativity) : matrix_scope.
Infix "⊕" := direct_sum (at level 20) : matrix_scope. (* should have different level and assoc *)
Infix "≡" := mat_equiv (at level 70) : matrix_scope.
Notation "A ⊤" := (transpose A) (at level 0) : matrix_scope.
Notation "A †" := (adjoint A) (at level 0) : matrix_scope.
Expand Down Expand Up @@ -427,6 +436,16 @@ Proof.
assumption.
Qed.

Lemma WF_direct_sum : forall {m n o p q r : nat} (A : Matrix m n) (B : Matrix o p),
q = m + o -> r = n + p ->
WF_Matrix A -> WF_Matrix B -> @WF_Matrix q r (A ⊕ B).
Proof.
unfold WF_Matrix, direct_sum.
intros; subst.
destruct H3; bdestruct_all; simpl; try apply H1; try apply H2.
all : lia.
Qed.

Lemma WF_transpose : forall {m n : nat} (A : Matrix m n),
WF_Matrix A -> WF_Matrix A⊤.
Proof. unfold WF_Matrix, transpose. intros m n A H x y H0. apply H.
Expand Down Expand Up @@ -484,6 +503,14 @@ Proof.
- apply WF_mult; assumption.
Qed.

Lemma WF_direct_sum_n : forall n {m1 m2} (A : Matrix m1 m2),
WF_Matrix A -> WF_Matrix (direct_sum_n n A).
Proof.
intros.
induction n; simpl.
- apply WF_Zero.
- apply WF_direct_sum; try lia; assumption.
Qed.

Lemma WF_Msum : forall d1 d2 n (f : nat -> Matrix d1 d2),
(forall i, (i < n)%nat -> WF_Matrix (f i)) ->
Expand Down Expand Up @@ -1256,6 +1283,17 @@ Lemma kron_mixed_product' : forall (m n n' o p q q' r mp nq or: nat)
(@kron m o p r (@Mmult m n o A C) (@Mmult p q r B D)).
Proof. intros. subst. apply kron_mixed_product. Qed.


Lemma direct_sum_assoc : forall {m n p q r s : nat}
(A : Matrix m n) (B : Matrix p q) (C : Matrix r s),
(A ⊕ B ⊕ C) = A ⊕ (B ⊕ C).
Proof. intros.
unfold direct_sum.
prep_matrix_equality.
bdestruct_all; simpl; auto.
repeat (apply f_equal_gen; try lia); easy.
Qed.

Lemma outer_product_eq : forall m (φ ψ : Matrix m 1),
φ = ψ -> outer_product φ φ = outer_product ψ ψ.
Proof. congruence. Qed.
Expand Down
12 changes: 6 additions & 6 deletions Prelim.v
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ Export ListNotations.
(** Boolean notation, lemmas *)

Notation "¬ b" := (negb b) (at level 75, right associativity). (* Level/associativity defined such that it does not clash with the standard library *)
Infix "⊕" := xorb (at level 20).
Infix "⊕" := xorb (at level 20).


Lemma xorb_nb_b : forall b, (¬ b) ⊕ b = true. Proof. destruct b; easy. Qed.
Lemma xorb_b_nb : forall b, b ⊕ (¬ b) = true. Proof. destruct b; easy. Qed.
Lemma xorb_nb_b : forall b, (¬ b) ⊕ b = true. Proof. destruct b; easy. Qed.
Lemma xorb_b_nb : forall b, b ⊕ (¬ b) = true. Proof. destruct b; easy. Qed.


Lemma xorb_involutive_l : forall b b', b ⊕ (b ⊕ b') = b'. Proof. destruct b, b'; easy. Qed.
Lemma xorb_involutive_r : forall b b', b ⊕ b' ⊕ b' = b. Proof. destruct b, b'; easy. Qed.
Lemma xorb_involutive_l : forall b b', b ⊕ (b ⊕ b') = b'. Proof. destruct b, b'; easy. Qed.
Lemma xorb_involutive_r : forall b b', b ⊕ b' ⊕ b' = b. Proof. destruct b, b'; easy. Qed.

Lemma andb_xorb_dist : forall b b1 b2, b && (b1 ⊕ b2) = (b && b1) ⊕ (b && b2).
Lemma andb_xorb_dist : forall b b1 b2, b && (b1 ⊕ b2) = (b && b1) ⊕ (b && b2).
Proof. destruct b, b1, b2; easy. Qed.

(** Nat lemmas *)
Expand Down
70 changes: 70 additions & 0 deletions Quantum.v
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,76 @@ Qed.
#[export] Hint Extern 2 (WF_Matrix (phase_shift _)) => apply WF_phase : wf_db.
#[export] Hint Extern 2 (WF_Matrix (control _)) => apply WF_control : wf_db.



(* how to make this proof shorter? *)
Lemma direct_sum_decomp : forall (m n o p : nat) (A B : Matrix m n),
WF_Matrix A -> WF_Matrix B ->
A ⊕ B = ∣0⟩⟨0∣ ⊗ A .+ ∣1⟩⟨1∣ ⊗ B.
Proof.
intros.
unfold direct_sum, kron, Mplus.
prep_matrix_equality.
bdestruct_all; try lia; simpl.
- repeat (rewrite Nat.div_small, Nat.mod_small; try easy); lca.
- rewrite H; auto.
destruct n. rewrite H, H0; try lca; try (right; lia).
rewrite (Nat.div_small x m), (Nat.mod_small x m); try easy.
replace (y / S n)%nat with (1 + (y - S n)/S n)%nat.
unfold Mmult, adjoint; simpl.
destruct (fst (Nat.divmod (y - S n) n 0 n)); try lca.
rewrite <- Nat.div_add_l; auto.
replace ((1 * S n + (y - S n)))%nat with y by lia; easy.
- rewrite H; auto.
destruct m. rewrite H, H0; try lca; try (left; lia).
rewrite (Nat.div_small y n), (Nat.mod_small y n); try easy.
replace (x / S m)%nat with (1 + (x - S m)/S m)%nat.
unfold Mmult, adjoint; simpl.
destruct (fst (Nat.divmod (x - S m) m 0 m)); try lca.
rewrite <- Nat.div_add_l; auto.
replace ((1 * S m + (x - S m)))%nat with x by lia; easy.
- destruct n; destruct m.
try (rewrite H, H0, H0; try lca);
try (left; lia); try (right; lia).
try (rewrite H, H0, H0; try lca);
try (left; lia); try (right; lia).
try (rewrite H, H0, H0; try lca);
try (left; lia); try (right; lia).
bdestruct (x - S m <? S m); bdestruct (y - S n <? S n).
replace (x / S m)%nat with 1%nat.
replace (y / S n)%nat with 1%nat.
replace (x mod S m) with (x - S m)%nat.
replace (y mod S n) with (y - S n)%nat.
lca.
replace y with ((y - S n) + 1*(S n))%nat by lia.
rewrite Nat.mod_add; try lia.
rewrite Nat.mod_small; lia.
replace x with ((x - S m) + 1*(S m))%nat by lia.
rewrite Nat.mod_add; try lia.
rewrite Nat.mod_small; lia.
replace y with ((y - S n) + 1*(S n))%nat by lia.
rewrite Nat.div_add; try lia.
rewrite Nat.div_small; lia.
replace x with ((x - S m) + 1*(S m))%nat by lia.
rewrite Nat.div_add; try lia.
rewrite Nat.div_small; lia.
rewrite H0; try (right; easy).
bdestruct (y <? 2*(S n)); try lia.
replace y with ((y - 2 * S n) + 2*(S n))%nat by lia.
rewrite Nat.div_add; try lia.
rewrite WF_braqubit0, WF_braqubit1; try lca; try (right; lia).
rewrite H0; try (left; easy).
bdestruct (x <? 2*(S m)); try lia.
replace x with ((x - 2 * S m) + 2*(S m))%nat by lia.
rewrite Nat.div_add; try lia.
rewrite WF_braqubit0, WF_braqubit1; try lca; try (left; lia).
rewrite H0; try (left; easy).
bdestruct (x <? 2*(S m)); try lia.
replace x with ((x - 2 * S m) + 2*(S m))%nat by lia.
rewrite Nat.div_add; try lia.
rewrite WF_braqubit0, WF_braqubit1; try lca; try (left; lia).
Qed.

(***************************)
(** Unitaries are unitary **)
(***************************)
Expand Down
6 changes: 3 additions & 3 deletions VectorStates.v
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ Qed.

Lemma f_to_vec_cnot : forall (n i j : nat) (f : nat -> bool),
i < n -> j < n -> i <> j ->
(pad_ctrl n i j σx) × (f_to_vec n f) = f_to_vec n (update f j (f j ⊕ f i)).
(pad_ctrl n i j σx) × (f_to_vec n f) = f_to_vec n (update f j (f j ⊕ f i)).
Proof.
intros.
unfold pad_ctrl, pad.
Expand Down Expand Up @@ -1117,8 +1117,8 @@ Proof.
repeat rewrite update_index_eq.
repeat rewrite update_index_neq by lia.
repeat rewrite update_index_eq.
replace ((f j ⊕ f i) ⊕ (f i ⊕ (f j ⊕ f i))) with (f i).
replace (f i ⊕ (f j ⊕ f i)) with (f j).
replace ((f j ⊕ f i) ⊕ (f i ⊕ (f j ⊕ f i))) with (f i).
replace (f i ⊕ (f j ⊕ f i)) with (f j).
rewrite update_twice_neq by auto.
rewrite update_twice_eq.
reflexivity.
Expand Down

0 comments on commit 6927616

Please sign in to comment.