forked from keon/algorithms
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmatrix_inversion.py
124 lines (105 loc) · 4.07 KB
/
matrix_inversion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Inverts an invertible n x n matrix -- i.e., given an n x n matrix A, returns
an n x n matrix B such that AB = BA = In, the n x n identity matrix.
For a 2 x 2 matrix, inversion is simple using the cofactor equation. For
larger matrices, this is a four step process:
1. calculate the matrix of minors: create an n x n matrix by considering each
position in the original matrix in turn. Exclude the current row and column
and calculate the determinant of the remaining matrix, then place that value
in the current position's equivalent in the matrix of minors.
2. create the matrix of cofactors: take the matrix of minors and multiply
alternate values by -1 in a checkerboard pattern.
3. adjugate: transpose the matrix of cofactors, i.e. keep the top-left to
bottom-right diagonal fixed and swap every other value with its mirror image
across that diagonal.
4. multiply the adjugated matrix by 1 / the determinant of the original matrix
This code combines steps 1 and 2 into one method to reduce traversals of the
matrix.
Possible edge cases: will not work for 0x0 or 1x1 matrix, though these are
trivial to calculate without use of this file.
"""
import fractions
def invert_matrix(m):
    """Invert a square matrix of size 2 x 2 or larger.

    Invalid input is reported by printing a message and returning a 1 x 1
    sentinel matrix rather than raising:

    * ``[[-1]]`` -- ``m`` is empty or its rows differ in length
    * ``[[-2]]`` -- ``m`` is not square
    * ``[[-3]]`` -- ``m`` is smaller than 2 x 2
    * ``[[-4]]`` -- ``m`` is singular (its determinant is 0)

    Args:
        m: the matrix, given as a list of row lists.

    Returns:
        A new list of row lists holding the inverse, or one of the sentinel
        matrices above. A 2 x 2 inverse is built with ordinary ``/``
        division. For larger matrices the entries are scaled by
        ``fractions.Fraction(1, det)`` (exact results for integer input);
        when the determinant is of a type ``Fraction`` rejects, such as a
        float, plain ``1 / det`` is used instead.
    """
    # Error conditions
    if not array_is_matrix(m):
        print("Invalid matrix: array is not a matrix")
        return [[-1]]
    if len(m) != len(m[0]):
        print("Invalid matrix: matrix is not square")
        return [[-2]]
    if len(m) < 2:
        print("Invalid matrix: matrix is too small")
        return [[-3]]

    # The determinant is found by recursive cofactor expansion, which is
    # costly, so evaluate it once and reuse it for both the singularity
    # check and the multiplier.
    det = get_determinant(m)
    if det == 0:
        print("Invalid matrix: matrix is square, but singular (determinant = 0)")
        return [[-4]]

    # Calculation
    if len(m) == 2:
        # simple case: swap the diagonal, negate the off-diagonal, scale
        multiplier = 1 / det
        return [
            [multiplier * m[1][1], multiplier * -1 * m[0][1]],
            [multiplier * -1 * m[1][0], multiplier * m[0][0]],
        ]

    # some steps are combined in helpers to reduce traversals
    # get matrix of minors w/ "checkerboard" signs
    m_of_minors = get_matrix_of_minors(m)
    try:
        multiplier = fractions.Fraction(1, det)
    except TypeError:
        # Fraction(numerator, denominator) only accepts rational arguments;
        # fall back to plain division for floats and similar types.
        multiplier = 1 / det
    # adjugate (swap on diagonals) and multiply by 1/det
    return transpose_and_multiply(m_of_minors, multiplier)
def get_determinant(m):
    """Recursively calculate the determinant of a square matrix.

    Matrices larger than 2 x 2 are expanded along their first row: each
    entry is multiplied by the determinant of the matrix left after removing
    the entry's row and column, and the terms are summed with alternating
    signs.

    Args:
        m: a square matrix given as a sequence of rows. Rows must support
            slicing and concatenation with ``+`` (e.g. lists).

    Returns:
        The determinant of ``m``. A 1 x 1 matrix yields its single entry and
        an empty matrix yields 1 (the empty product).
    """
    size = len(m)
    if size == 0:
        return 1
    if size == 1:
        return m[0][0]
    if size == 2:
        # trivial case
        return (m[0][0] * m[1][1]) - (m[0][1] * m[1][0])

    det = 0
    sign = 1
    lower_rows = m[1:]  # row 0 is excluded from every minor of this expansion
    for col in range(size):
        # minor of position (0, col): drop column `col` from the lower rows
        minor = [line[:col] + line[col + 1:] for line in lower_rows]
        det += sign * m[0][col] * get_determinant(minor)
        sign *= -1
    return det
def get_matrix_of_minors(m):
    """Build the matrix of minors with the checkerboard signs applied.

    Each position (r, c) of the result holds the determinant of the minor of
    ``m`` at (r, c), negated whenever ``r + c`` is odd.
    """
    size = len(m)
    signed_minors = [[0] * size for _ in range(size)]
    for r in range(size):
        for c in range(len(m[0])):
            # +1 where r + c is even, -1 where it is odd
            parity = -1 if (r + c) % 2 else 1
            signed_minors[r][c] = parity * get_determinant(get_minor(m, r, c))
    return signed_minors
def get_minor(m, row, col):
    """
    Return the minor of matrix position m[row][col]: a new matrix made of
    every value m[r][c] with r != row and c != col. ``m`` is not modified.
    """
    kept_rows = [line for index, line in enumerate(m) if index != row]
    submatrix = []
    for line in kept_rows:
        # slicing copies, so extending the left part never touches `m`
        trimmed = line[:col]
        trimmed.extend(line[col + 1:])
        submatrix.append(trimmed)
    return submatrix
def transpose_and_multiply(m, multiplier=1):
    """Transpose m in place, scaling every value by multiplier.

    Values are swapped across the top-left to bottom-right diagonal and each
    one is multiplied by ``multiplier`` exactly once. The same (mutated)
    list object is returned.
    """
    for r in range(len(m)):
        # walk the lower triangle, diagonal included; each step handles
        # position (r, c) together with its mirror (c, r)
        for c in range(r + 1):
            scaled_lower = m[r][c] * multiplier
            scaled_upper = m[c][r] * multiplier
            m[r][c], m[c][r] = scaled_upper, scaled_lower
    return m
def array_is_matrix(m):
    """Return True when m has at least one row and all rows are equally long."""
    row_count = len(m)
    if row_count == 0:
        return False
    expected_width = len(m[0])
    return all(len(line) == expected_width for line in m)