Skip to content

Commit

Permalink
feat: fix false AKS positives
Browse files Browse the repository at this point in the history
  • Loading branch information
h5law committed May 22, 2024
1 parent 238e7a7 commit 48ec864
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 87 deletions.
195 changes: 110 additions & 85 deletions aks.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package primality

import (
"fmt"
"math"
"slices"
)
Expand All @@ -26,38 +27,53 @@ func fastPowerMod(b, e, m int) int {
}
for e > 0 {
if e%2 == 1 {
r *= b % m
r = (r * b) % m
}
b *= b % m
b = (b * b) % m
e >>= 1
}
return r

}

// Check n != a^b for a,b > 1, returning true if it is otherwise false if not.
func basePowerCheck(n int) bool {
for i := 2.0; i < math.Log2(float64(n)); i++ {
a := math.Pow(float64(n), 1.0/float64(i))
// Check if the rounded value a is equal to its integer value.
if i == a {
return true // n is a composite number.
var bMax = int(math.Log2(float64(n))) + 1
for b := 2; b <= bMax; b++ {
var aMin = 2
var aMax = int(math.Pow(float64(2), float64(64)/float64(b))) - 1
if fastPower(b, aMin) == n {
return true
}
if fastPower(b, aMax) == n {
return true
}
// binary search for a base a for exponent b
for aMax-aMin != 1 {
var aMid = (aMin + aMax) / 2
var nMid = fastPower(aMid, b)
if nMid == n {
return true
} else if nMid < n {
aMin = aMid
} else { // nMid > n
aMax = aMid
}
}
}
return false
}

// orderMod finds the order of coprimes a and r such that r^a % i = 1,
// orderMod finds the order of coprimes a and r such that a^i % r = 1,
// where i is the order found
func orderMod(a, r int) int {
// Check if a and r are coprime
g := gcd(a, r)
if g != 1 {
if gcd(a, r) != 1 {
return 0
}
// Find the order i such that r^a % i = 1
i := 1
// Find the order i such that a^i % r = 1
i := 0
for {
fmt.Println(i)
if fastPowerMod(a, i, r) == 1 {
return i
}
Expand All @@ -67,28 +83,25 @@ func orderMod(a, r int) int {

// Find the smallest order such that Ord_r(n) > (log_2(n))^2
func findCorrectOrder(n int) int {
// fn := float64(n)
// maxK := int(math.Floor(math.Log2(fn)*math.Log2(fn) + 0.5))
// r := 2
// for {
// if orderMod(n, r) > maxK {
// return r
// }
// r++
// }
fn := float64(n)
maxK := math.Floor(math.Log2(fn)*math.Log2(fn) + 0.5)
maxR := max(3.0, math.Pow(math.Log2(fn), 5))
nextR := true
var i float64
for i = 2.0; nextR && i < maxR; i++ {
nextR = false
for j := 1; !nextR && j <= int(maxK); j++ {
nextR = fastPowerMod(n, j, int(i)) == 1 ||
fastPowerMod(n, j, int(i)) == 0
lower := int(math.Floor(math.Log2(fn)*math.Log2(fn) + 0.5))
upperR := max(3, int(math.Floor(math.Pow(math.Log2(fn), 5)+0.5)))
for r := 1; r < upperR; r++ {
if gcd(r, n) != 1 {
continue
}
k := 2
for {
if fastPowerMod(r, k, n) == 1 {
break
}
k++
}
if r > lower && r != n {
return r - 1
}
}
return int(i - 1.0)
return 0
}

// gcd finds the greatest common divisor of a and b
Expand Down Expand Up @@ -138,10 +151,14 @@ func polynomialMod(p []int, m int) []int {
r := make([]int, len(p))
for i, x := range p {
mod := x % m
if mod < 0 {
mod += x
}
if mod != 0 {
r[i] = mod
}
}
stripTrailingZeros(&r)
return r
}

Expand All @@ -162,68 +179,82 @@ func polynomialExpansion(e, a int) []int {
return c
}

// degree determines the degree of the polynomial p
//
// Ref: https://rosettacode.org/wiki/Polynomial_long_division#Go
func stripTrailingZeros(p *[]int) {
for i := len(*p) - 1; i >= 0; i-- {
if (*p)[i] != 0 {
*p = (*p)[:i+1]
return
}
}
*p = nil
}

func degree(p []int) int {
for d := len(p) - 1; d >= 0; d-- {
if p[d] != 0 {
return d
for i := len(p) - 1; i >= 0; i-- {
if i != 0 {
return i
}
}
return -1
}

// pld performs polynomial long division on the two polynomial coefficient slices
// provided. It expects the polynomials to be in ascending order of powers of x.
//
// Ref: https://rosettacode.org/wiki/Polynomial_long_division#Go
func pld(nn, dd []int) (q, r []int, ok bool) {
dnn := degree(nn)
ddd := degree(dd)
if ddd < 0 {
return
func polynomialMultiplyScalar(p []int, n int) []int {
res := make([]int, len(p))
for i := range p {
res[i] = p[i] * n
}
nn = append(r, nn...)
if dnn >= ddd {
q = make([]int, dnn-ddd+1)
for dnn >= ddd {
d := make([]int, dnn+1)
copy(d[dnn-ddd:], dd)
q[dnn-ddd] = nn[dnn] / d[degree((d))]
for i := range d {
d[i] *= q[dnn-ddd]
nn[i] -= d[i]
}
dnn = degree(nn)
}
return res
}

func polynomialRemainder(p1, p2 []int) []int {
dp1 := degree(p1)
dp2 := degree(p2)
if dp2 < 0 {
return nil
}
if dp1 < dp2 {
return p1
}
q := make([]int, dp1)
for dp1 >= dp2 {
d := make([]int, dp1+1)
copy(d[dp1-dp2:], p2)
q[dp1-dp2] = p1[dp1] / d[len(d)-1]
d = polynomialMultiplyScalar(d, q[dp1-dp2])
p1 = polynomialSubtraction(p1, d)
dp1 = degree(p1)
}
return q, nn, true
return p1
}

// polynomialModRemainder finds the remainder of polynomials p1/p2 and does a
// term-wise reduction modulo m on the result, returning a slice of coefficients
// for a polynomial in ascending order of x.
func polynomialModRemainder(p1, p2 []int, m int) []int {
_, r, ok := pld(p1, p2)
if !ok {
return nil
}
return polynomialMod(r, m)
return polynomialMod(polynomialRemainder(p1, p2), m)
}

// polynomialSubtraction subtracts p1 from p2 in a term-wise fashion.
// The function orders the polynomials longest first prior to subtraction.
func polynomialSubtraction(p1, p2 []int) []int {
longest, shortest := p1, p2
if len(p2) > len(p2) {
longest, shortest = p2, p1
res := make([]int, len(p1))
i := 0
for len(p2) > 0 && len(p1) > 0 {
res[i] = p1[0] - p2[0]
p1 = p1[1:]
p2 = p2[1:]
i++
}
res := make([]int, len(longest)+1)
for i, x := range shortest {
res[i] = longest[i] - x
stripTrailingZeros(&res)
if len(p1) > 0 {
res = append(res, p1...)
}
copy(res[len(shortest):], longest[len(shortest)-1:])
if len(p2) > 0 {
for _, x := range p2 {
res = append(res, -x)
}
}
stripTrailingZeros(&res)
return res
}

Expand Down Expand Up @@ -263,20 +294,14 @@ func AKS(n int) bool {
)
xna := polynomialExpansion(n, 0)
xr1 := polynomialExpansion(r, 0)
xr1[0]--
var xa, remA, remB []int // stop throwing away each loop
for a := 1; a <= maxA; a++ {
xna[0] += a
xa := polynomialExpansion(n, a)
remA := polynomialModRemainder(xa, xr1, n)
_, remB, ok := pld(xna, xr1)
if !ok {
panic("error dividing polynomials")
}
longest := make([]int, len(remA))
if len(remB) > len(remA) {
longest = make([]int, len(remB))
}
if slices.Equal(polynomialSubtraction(remA, remB), longest) {
xa = polynomialExpansion(n, a)
remA = polynomialModRemainder(xa, xr1, n)
remB = polynomialRemainder(xna, xr1)
// fmt.Println(remA, remB)
if slices.Equal(remA, remB) {
return false
}
}
Expand Down
5 changes: 3 additions & 2 deletions primes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"math/big"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -34,15 +35,15 @@ func TestPrimalityTests_Equal(t *testing.T) {
primes1 = append(primes1, i)
}
}
require.Len(t, primes1, 168)
assert.Len(t, primes1, 168)

primes2 := make([]int, 0, 168)
for i := 1; i <= 1000; i++ {
if aks := AKS(i); aks {
primes2 = append(primes2, i)
}
}
require.Len(t, primes2, 168)
assert.Len(t, primes2, 168)

require.Equal(t, primes1, primes2)

Expand Down
73 changes: 73 additions & 0 deletions units_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package primality

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAKSSteps_One(t *testing.T) {
composite := basePowerCheck(31) // prime
require.False(t, composite)
composite = basePowerCheck(49) // 7^2
require.True(t, composite)
composite = basePowerCheck(121) // 11^2
require.True(t, composite)
}

func TestAKSSteps_Two(t *testing.T) {
assert.Equal(t, 8, findCorrectOrder(7))
assert.Equal(t, 25, findCorrectOrder(31))
}

func TestAKSSteps_FastPowerMod(t *testing.T) {
require.Equal(t, 1, fastPowerMod(4, 3, 7))
require.Equal(t, 0, fastPowerMod(4, 4, 8))
require.Equal(t, 27, fastPowerMod(29, 17, 31))
}

func TestAKSSteps_PolynomialExpansion(t *testing.T) {
p1 := polynomialExpansion(3, 0)
require.Equal(t, []int{1, 3, 3, 1}, p1)

p2 := polynomialExpansion(4, 2)
require.Equal(t, []int{16, 32, 24, 8, 1}, p2)

p3 := polynomialExpansion(5, -3)
require.Equal(t, []int{-243, 405, -270, 90, -15, 1}, p3)
}

func TestAKSSteps_PolynomialRemainder(t *testing.T) {
p1 := polynomialExpansion(6, 3)
p2 := polynomialExpansion(5, 2)
rem := polynomialRemainder(p1, p2)
require.Equal(t, []int{473, 786, 495, 140, 15}, rem)

p1 = polynomialExpansion(5, 3)
p2 = polynomialExpansion(3, -1)
rem = polynomialRemainder(p1, p2)
require.Equal(t, []int{384, 0, 640}, rem)
}

func TestAKSSteps_PolynomialSubtraction(t *testing.T) {
p1 := polynomialExpansion(6, 3)
p2 := polynomialExpansion(5, 2)
sub := polynomialSubtraction(p1, p2)
require.Equal(t, []int{697, 1378, 1135, 500, 125, 17, 1}, sub)

sub = polynomialSubtraction(p2, p1)
require.Equal(t, []int{-697, -1378, -1135, -500, -125, -17, -1}, sub)
}

func TestAKSSteps_PolymomialScalarMultiply(t *testing.T) {
p1 := polynomialExpansion(4, 2)
scalar := 3
mul := polynomialMultiplyScalar(p1, scalar)
require.Equal(t, []int{48, 96, 72, 24, 3}, mul)

p2 := polynomialExpansion(3, -2)
scalar = 6
mul = polynomialMultiplyScalar(p2, scalar)
require.Equal(t, []int{-48, 72, -36, 6}, mul)
}

0 comments on commit 48ec864

Please sign in to comment.