diff --git a/java/p211.java b/java/p211.java index d5083ff8183f24ec74c519b14e8dca90e2c6a10a..3903bb30003352cb1e33fe0e87d2246215c6ad9b 100644 --- a/java/p211.java +++ b/java/p211.java @@ -34,8 +34,9 @@ public final class p211 implements EulerSolution { } long sum = 0; + SquareTester sqt = new SquareTester(3 * 5 * 7 * 11 * 13 * 17); for (int i = 1; i < sigma2.length; i++) { - if (isPerfectSquare(sigma2[i])) + if (sqt.isPerfectSquare(sigma2[i])) sum += i; } return Long.toString(sum); @@ -57,14 +58,45 @@ public final class p211 implements EulerSolution { } - private static boolean isPerfectSquare(long x) { - long y = 0; - for (long i = 1L << 31; i != 0; i >>>= 1) { - y |= i; - if (y > 3037000499L || y * y > x) - y ^= i; + + // Consider the set of all squared natural numbers, i.e. {0, 1, 4, 9, 16, 25, ...}. + // When this set is viewed modulo some number n, usually not every residue is in the set. + // For example, all squares modulo 3 is {0, 1} - so a perfect square modulo 3 is never 2. + // By choosing a suitably large modulus, we can . + private static final class SquareTester { + + // isResidue[i] is true iff there exists a natural number k such that k^2 = i mod modulus. + // Hence for any k, if isResidue[k] is false then k is not a perfect square. + private boolean[] isResidue; + + + // Any product of unique small prime numbers excluding 2 makes a good modulus + // that leads to fast tests. But the behavior is correct for any modulus >= 1. + public SquareTester(int modulus) { + if (modulus < 1) + throw new IllegalArgumentException(); + isResidue = new boolean[modulus]; + for (int i = 0; i < modulus; i++) + isResidue[(int)((long)i * i % modulus)] = true; } - return y * y == x; + + + public boolean isPerfectSquare(long x) { + // Reject many but not all numbers that aren't a perfect square. + // This speed optimization can be omitted without affecting correctness. + if (!isResidue[(int)(x % isResidue.length)]) + return false; + + // A complete algorithm for detecting squares + long y = 0; + for (long i = 1L << 31; i != 0; i >>>= 1) { + y |= i; + if (y > 3037000499L || y * y > x) + y ^= i; + } + return y * y == x; + } + } }