FFT in competitive programming

2015. 7. 18. 02:15알고리즘 문제풀기/기타 주제

다항식 찾기

(n+1)개의 점, 즉 $(x_i, y_i)$의 순서쌍이 (n+1)개 주어지면, 이 n개의 점을 모두 지나는 다항식을 하나 찾을 수 있다. 즉 $P(x_i) = y_i$를 모두 만족하는 $P(x)$를 찾을 수 있다. (일단 각각의 $x_i$는 모두 다르다고 하자.)
예를 들어 (1, 1)과 (2, 3)을 지나는 다항식은 $f(x)=2x-1$도 있고, $g(x)=x^2-x+1$도 있다.
이런 다항식은 여러 가지가 있고, 무한히 많다.
그런데 사실은, 그중 차수가 가장 낮은 건 n차이고, 게다가 이때의 n차 다항식은 무려 유일하다.

다항식을 왜 찾을까? ― 다항식 곱셈 문제

잠깐 다른 얘기를 해본다.
두 다항식의 곱을 계산하는 문제를 생각하자. 예를 들면 $(x^2 + x + 3)$과 $(x - 2)$의 곱은
$$(x^2 + x + 3)(x - 2) = x^3 - x^2 + x - 6$$
우리가 알고 있는 그 다항식 곱셈이다. 단순히 전개해서 정리하면 되는 그것이다.
컴퓨터로 짠다면 어떻게 할까? 왼쪽 다항식의 i차 항(=$c x^i$)에 오른쪽 다항식의 j차 항(=$d x^j$)이 곱해지면 최종 다항식의 (i+j)차 항에 계수의 곱(=$c \cdot d$)이 더해진다.
따라서 답을 저장할 배열을 만들어두고 왼쪽과 오른쪽의 모든 항을 돌면서 더해주면 된다.

for(int i = 0; i <= n; ++i)
    for(int j = 0; j <= m; ++j)
        c[i+j] += a[i] * b[j];

그런데 다항식 곱셈을 빨리 하고 싶다는 요구가 있었다. 이것보다 빨리? 지금은 느린가?
한 쪽이 n차, 다른 쪽이 m차 다항식이라면 우리의 방법은 위의 코드에서도 알 수 있듯 $O(nm)$ 시간이 걸린다.
물론 n과 m이 10만 단위나 그 이상이라면 (PS적 의미에서) 시간 내에 작동하기 어렵다.

하지만 다항식 곱셈의 정의를 생각해보면 우리는 정말 정직하게 꼭 필요한 작업만 수행한 것 뿐인데도 느린 셈이다.
똑같은 계산이 여러 번 나타난다면 또 모르나 여기서는 모든 계산이 독립적이다.
그럼 이걸 줄이는 게 가능키나 한가? 막막하다.

이때 이런 생각을 해본다. 여기가 핵심이다.
몇 개의 x 값에 대해 그걸 대입한 결과인 A(x)와 B(x)를 계산할 수 있다면,
둘을 곱한 다항식 $C(x)=A(x)B(x)$는 $(x_i, A(x_i) B(x_i))$를 모두 지날 것이다.

A가 n차, B가 m차라고 하면, 이런 x를 (n+m+1)개 잡아서 A와 B값을 곱해 점을 총 (n+m+1)개 만들면,
다항식 C(x)가 유일하게 결정될 것이다.

그런데 (n+m+1)개의 x를 A와 B에 각각 대입하는 과정에서 다시 문제가 생긴다.
하나의 x값을 단순하게 대입하면 각각 $O(n)$과 $O(m)$의 시간이 걸린다. $a_i x^i$를 모두 계산해서 합해야 하기 때문이다.
그러면 걸리는 총 시간은 $O(n^2 + m^2)$ 쯤 되므로, 우리의 목표인 다항식 곱셈을 제곱 미만에 하기에는 못 미치는 결과이다.

아래 두 가지를 빠르게, 제곱 미만에 할 수 있다면 우리의 문제가 해결될 텐데...

  • (n+m+1) 개의 서로 다른 x값을 A와 B에 대입한 결과를 각각 알기
  • (x, y)의 쌍이 (n+m+1)개 주어지면 이를 모두 지나는 (n+m)차 다항식 얻기

Fourier transform

이 두 가지가 모두 제곱 미만의 시간에 해결 가능하며 그 방법을 Fast Fourier transform으로 일컫는다.
이 글에서는 그 방법을 소개한다.

먼저 A와 B의 차수가 같으며 $2^t - 1$이라고 가정한다. 만일 원래 차수가 이보다 적으면 계수가 0인 항들을 추가하면 된다.
이제 $n=2^t$으로 표기한다. A와 B의 차수는 (n-1)이다.

갑자기 다음과 같이 $\omega$를 정의한다.
$$ \omega_{n} = \textrm{e} ^ {2 \pi i / n} = \cos \left(\frac{2 \pi}{n} \right) + i \sin \left( \frac{2 \pi}{n} \right) $$
여기서 $i=\sqrt{-1}$이다.
이제 우리가 대입할 x값들은 다음과 같다.
$$ x_k = \omega_n ^k = \textrm{e} ^ {2 \pi i k / n} = \cos \left(\frac{2 \pi k}{n} \right) + i \sin \left( \frac{2 \pi k}{n} \right) $$

처음 봤을 땐 정말 "왜?" 이 생각밖에 안 든다. 이건 시간이 지날 수록 이해가 깊어지는 개념이라 어쩔 수 없다.
그 전에 잠깐 몇 가지 특징을 짚고 넘어가자.
$\omega_n$은 이런 성질이 있다. 읽기 편하게 잠깐 아래첨자를 떼고 $\omega$로 쓴다.
$$\begin{array}{l} \omega^0 = \omega^n = 1 \\ \omega^k \neq 1\textrm{ for all }k \neq 0, n, 2n, \cdots \\ \omega^{n/2} = -1 \end{array}$$
복소평면 상에서 편각이 $2 \pi / n$이고 크기가 1이므로, 단위원(반지름이 1)을 n조각으로 나누는 모양을 가지기도 한다.

root of unity, from Wikipedia

위의 그림은 n=5일 때의 경우이므로, 앞으로 할 논의와 일부 맞지 않는 부분이 있다. 느낌만 참고하자.

이제 $A(x_k)$를 구하는 방법을 생각해보자.
$$ A(x_k) = a_0 x_k^0 + a_1 x_k^1 + \cdots + a_{n-1} x_k^{n-1} $$
여기서 갑자기 홀수차 항과 짝수차 항을 분리한다.

$$\begin{array}{l}
= (a_0 x_k^0 + a_2 x_k^2 + \cdots + a_{n-2} x_k^{n-2}) +
(a_1 x_k^1 + a_3 x_k^3 + \cdots + a_{n-1} x_k^{n-1}) \\
= (a_0 x_k^0 + a_2 x_k^2 + \cdots + a_{n-2} x_k^{n-2}) +
x_k (a_1 x_k^0 + a_3 x_k^2 + \cdots + a_{n-1} x_k^{n-2})
\end{array} $$

음? 짝수차 항과 홀수차 항을 따로 모아 각각 다항식을 만들어보자.

$$\begin{array}{l}
f(x)=a_0 x + a_2 x^1 + a_4 x^2 + \cdots + a_{n-2} x_k ^{n/2-1} \\
g(x)=a_1 x + a_3 x^1 + a_5 x^2 + \cdots + a_{n-1} x_k ^{n/2-1} \\
A(x_k) = f(x_k^2) + x_k g(x_k^2)
\end{array} $$

식을 멋있게 정리는 했는데 각각의 다항식에 값을 대입한다고 해서 문제가 갑자기 해결되지는 않는다.
어차피 계산량은 동일하므로.

f와 g에 넣는 값을 보자.
$$x_k^2 = \left(\omega_{n}^{k}\right)^2 = \omega_{n} ^ {2k} = \omega_{n/2}^{k}$$
f와 g가 둘 다 ${n/2}-1$차 다항식이고, 여기에 $\omega_{n/2}^k$를 대입한 값을 쓰고 있다.

오잉? $(n-1)$차 다항식 A에 대해 $A(\omega_{n}^k)$를 대입한 값을 구하기 위해서 필요한 게,
$(n/2-1)$차 다항식 f와 g에 대해 $A(\omega_{n/2}^k)$를 대입한 값이다.
n=8일 때의 예시를 보면 이해가 빠를 것이다.

$$\begin{array}{l}
A(\omega^0) =f(\omega^0) + \omega^0 g(\omega^0) = f(\Omega^0) + \omega^0 g(\Omega^0) \\
A(\omega^1) =f(\omega^2) + \omega^1 g(\omega^2) = f(\Omega^1) + \omega^1 g(\Omega^1) \\
A(\omega^2) =f(\omega^4) + \omega^2 g(\omega^4) = f(\Omega^2) + \omega^2 g(\Omega^2) \\
A(\omega^3) =f(\omega^6) + \omega^3 g(\omega^6) = f(\Omega^3) + \omega^3 g(\Omega^3) \\
A(\omega^4) =f(\omega^8) + \omega^4 g(\omega^8) = f(\Omega^4) + \omega^4 g(\Omega^4) \\
A(\omega^5) =f(\omega^{10}) + \omega^5 g(\omega^{10}) = f(\Omega^5) + \omega^5 g(\Omega^5) \\
A(\omega^6) =f(\omega^{12}) + \omega^6 g(\omega^{12}) = f(\Omega^6) + \omega^6 g(\Omega^6) \\
A(\omega^7) =f(\omega^{14}) + \omega^7 g(\omega^{14}) = f(\Omega^7) + \omega^7 g(\Omega^7)
\end{array} $$

편의를 위해 $\omega = \omega_8, \Omega = \omega_4$로 썼다.
앞서 말했듯이 이걸 한 줄 한 줄 단순 대입으로 계산하면 계산량에 변화가 없으므로 아무런 진전이 없다.

그런데 식의 오른쪽에서 $f(\Omega^4)$ 같은 항이 보인다. $\Omega$는 네제곱하면 1이 되는 수이므로, 이건 $f(\Omega^0)$과 같은 값이다.
마찬가지로 $f(\Omega^5) = f(\Omega^1)$이고, 이런 식으로 한 번 계산한 값을 또 쓰는 부분이 계속 있어 보인다.
다시 정리를 해본다.

$$
A(\omega^0) = f(\Omega^0) + \omega^0 g(\Omega^0) \\
A(\omega^1) = f(\Omega^1) + \omega^1 g(\Omega^1) \\
A(\omega^2) = f(\Omega^2) + \omega^2 g(\Omega^2) \\
A(\omega^3) = f(\Omega^3) + \omega^3 g(\Omega^3) \\
\space \\
A(\omega^4) = f(\Omega^0) + \omega^4 g(\Omega^0) \\
A(\omega^5) = f(\Omega^1) + \omega^5 g(\Omega^1) \\
A(\omega^6) = f(\Omega^2) + \omega^6 g(\Omega^2) \\
A(\omega^7) = f(\Omega^3) + \omega^7 g(\Omega^3)
$$

$A(x_0), \: A(x_1), \cdots$를 차례대로 구하는 게 아니라, 먼저 f와 g에 대한 모든 값을 먼저 계산해 둔 후 가져다 쓰면 뭔가 절약될 듯한 느낌이 든다.
실제로 이렇게 하는 것만으로 시간이 확 줄어든다.
n짜리 다항식에 모든 x좌표를 대입한 값을 구하는 데에 걸리는 시간을 써보면 $T(n)=2T(n/2) + O(n)$ 이므로, $T(n) = O(n \log n)$이다. (시간복잡도 분석)
이제 이런 코드를 짤 수 있다.

typedef complex<double> C;
const double pi = acos(-1);

vector<C> fft_full_rec(vector<int> coeff){
    int n = coeff.size();
    if(n == 1){
        vector<C> ret;
        ret.push_back(coeff[0]);
        return ret;
    }
    C omega = polar(1., 2*pi/n);
    vector<int> a[2];
    for(int i=0; i<n; ++i) a[i&1].push_back(coeff[i]);
    vector<C> fa=fft_full_rec(a[0]);
    vector<C> fb=fft_full_rec(a[1]);
    vector<C> fc;
    for(int k=0; k<n; ++k) fc.push_back(fa[i%(n/2)] + pow(omega, i)*fb[i%(n/2)]);
    return fc;
}

이게 '변환'인 이유는 일단 어떤 주어진 값들을 다른 값들로 바꿨기 때문이다.
n개의 수((n-1)차 다항식의 계수 갯수만큼)를 가지고 n개의 수($\omega^0$부터 $\omega^{(n-1)}$까지 대입)를 얻었기 때문에 이 과정에서 정보의 손실도 없었다.

보통 '푸리에 변환'(Fourier transform)이라고 일컫는 개념이 여럿인데 여기서 사용한 것은 "이산 푸리에 변환"(discrete Fourier transform)이다. 유한한 원소에 대해서 $e^(ik \cdot j)$를 곱해서 합하기 때문이다. 수학, 자연과학 및 공학에서 푸리에 변환을 말하는 경우 주로 정의역(여기서는 {0, 1, ..., n-1}이었다)이 연속적인 함수의 함숫값에 $e^(ik \cdot x)$를 곱해서 적분하는 식으로 계산하는 푸리에 변환을 말한다. k는 주파수로 여기며 이 푸리에 변환이 가지는 여러 성질이 있다.

조금 더 빠르게

시간복잡도는 동일하지만 더 빠르게 작동하는 구현에 대해서 논한다.

위의 코드는 계속해서 vector를 생성하고 넘겨주기 때문에 조금 느리다. 즉 짝수와 홀수차 항을 분리할 때 새롭게 vector를 만드는 게 조금 오래 걸린다. (미리 크기를 reserve해주면 조금 낫긴 하지만, 지속적인 메모리 할당 및 해제가 일어나는 점은 마찬가지이다.)
그런데 실은 짝수차 항은 나의 0번부터 두 칸씩, 홀수차는 1번부터 두 칸씩 건너뛰는 것일 뿐이다.
따라서 계수의 배열은 그대로 계속 쓰고, 이 정보(시작점, 건너뛰는 너비, 항의 갯수)만 넘겨주면 이 문제가 해결된다.

또 아까 (n/2)승을 취하면 -1이 되는 점을 이용하면 제곱 부분을 다음과 같이 계산할 수 있다.
$$
A(\omega^0) = f(\Omega^0) + \omega^0 g(\Omega^0) \\
A(\omega^1) = f(\Omega^1) + \omega^1 g(\Omega^1) \\
A(\omega^2) = f(\Omega^2) + \omega^2 g(\Omega^2) \\
A(\omega^3) = f(\Omega^3) + \omega^3 g(\Omega^3) \\
\space \\
A(\omega^4) = f(\Omega^0) + \omega^4 g(\Omega^0) = f(\Omega^0) - \omega^0 g(\Omega^0) \\
A(\omega^5) = f(\Omega^1) + \omega^5 g(\Omega^1) = f(\Omega^1) - \omega^1 g(\Omega^1) \\
A(\omega^6) = f(\Omega^2) + \omega^6 g(\Omega^2) = f(\Omega^2) - \omega^2 g(\Omega^2) \\
A(\omega^7) = f(\Omega^3) + \omega^7 g(\Omega^3) = f(\Omega^3) - \omega^3 g(\Omega^3) \\
$$
이게 도움이 되는 이유는 $\omega^{k}$를 계산하는 과정이 실수 연산이므로 속도가 느려서 그렇다.
이걸 이용하면 반환값을 적는 배열 역시 매번 새로 만들 필요가 없다.
아래 표와 같이 위치를 정해주면, 두 값을 가져오는 위치와 계산해서 넣는 위치가 동일하기 때문이다.

원래값 새로운값
f(^0) f(^0) + ω^0 g(^0)
f(^1) f(^1) + ω^1 g(^1)
f(^2) f(^2) + ω^2 g(^2)
f(^3) f(^3) + ω^3 g(^3)
g(^0) f(^0) - ω^0 g(^0)
g(^1) f(^1) - ω^1 g(^1)
g(^2) f(^2) - ω^2 g(^2)
g(^3) f(^3) - ω^3 g(^3)

이를 구현한 코드는 다음과 같다.

void fft_quick_rec_do(vector<int>& coeff, vector<C>& ret, int start, int step, int n, int save_to){
    if(n == 1){
        ret[save_to] = coeff[start];
        return;
    }
    C omega = polar(1., 2*pi/n);

    fft_quick_rec_do(coeff, ret, start     , step*2, n/2, save_to);
    fft_quick_rec_do(coeff, ret, start+step, step*2, n/2, save_to+n/2);

    for(int i=0; i<n/2; ++i){
        auto a=ret[save_to+i];
        auto b=pow(omega, i)*ret[save_to+i+n/2];
        ret[save_to+i] = a+b;
        ret[save_to+i+n/2] = a-b;
    }
}

vector<C> fft_quick_rec(vector<int> coeff){
    vector<C> ret(coeff.size());
    fft_quick_rec_do(coeff, ret, 0, 1, coeff.size(), 0);
    return ret;
}

조금 더 머리를 쓰면 재귀호출마저 없앨 수 있다.
위의 코드에서, 각 단계에서 특정 비트가 0인 것과 1인 것을 분리해서 처리한 다음에 합치는 과정을 이해하는 것이 우선이다.

   앞쪽   처리중   공통
......        0    011  <- 이것들과
......        1    011  <- 이것들을 각각 재귀적으로 처리한 후,

   앞쪽   처리중   공통
111010        0    011  =: a
111010        1    011  =: b로 가져온 값이

0        111010    011  := a+ω_(2^7)^(111010) b
1        111010    011  := a-ω_(2^7)^(111010) b
로 들어가는 과정을 모든 "앞쪽"에 대해 진행한다. (서로 겹치지 않게)

코드를 짜면

vector<C> fft_nonrec(vector<int> coeff){
    vector<C> ret(coeff.size()), tmp(coeff.size());
    int n=0;
    while((1<<n) != int(coeff.size())) ++n;
    for(int i=0; i<(1<<n); ++i) ret[i]=coeff[i];

    for(int depth=n-1; 0<=depth; --depth){
        int sz_down = (1<<depth);
        int sz_up = (1<<(n-depth-1));
        C omega = polar(1., 2*pi/(sz_up*2));
        for(int down=0; down<sz_down; ++down){
            for(int up=0; up<sz_up; ++up){
                C a = ret[(up << (depth+1)) | down];
                C b = ret[(up << (depth+1)) | down | (1 << depth)];
                b *= pow(omega, up);
                tmp[(up << depth) | down] = a + b;
                tmp[(up << depth) | down | (1 << (n-1))] = a - b;
            }
            for(int up=0; up<sz_up*2; ++up){
                ret[(up << depth) | down] = tmp[(up << depth) | down];
            }
        }
    }
    return ret;
}

이 방법이 재귀 호출보다 많이 괜찮긴 한데 여기서 심지어 더 고속화할 수 있다.
이 방법이 그나마 느린 이유는, 위의 예시에서 아래의 두 값

0        111010    011
1        111010    011

이 멀리(2^9만큼) 떨어져 있기 때문이다.
아래쪽 비트를 그대로 두고 위쪽 비트로 반복문을 돌리고 있는데, 이러면 참조하는 값들 사이의 간격이 점점 멀어지면서 점점 캐시 미스가 발생한다.

그럼 반환값 배열에 사용할 인덱스의 비트를 모두 좌우로 뒤집으면 어떨까?

공통   처리중    뒷쪽
 110       0  ......  <- 이것들과
 110       1  ......  <- 이것들을 각각 재귀적으로 처리한 후,

공통   처리중    뒷쪽
 110       0  010111  =: a
 110       1  010111  =: b 로 가져온 값을

 110  010111       0  := a+ω_(2^7)^(111010) b
 110  010111       1  := a-ω_(2^7)^(111010) b
에 넣는 과정을 모든 "뒤쪽"에 대해 반복한다. (서로 겹치지 않게)

이렇게 하고 공통된 부분을 먼저 정하고 나면 그 아래쪽에서는 캐시 미스가 덜 난다.
"처리중"에 해당하는 비트가 바뀔 때 간격이 멀어지기는 하지만 이건 어쩔 수 없고, 그 다음에 "뒷쪽"으로 돌리는 반복문은 계속 연속한 값을 참조하므로 캐시 히트가 잘 발생한다.

그리고 한번 더 충격적인 최적화를 한다.
가져온 값을 대입하는 위치가 가져온 위치와 다른데 그냥 그 자리에 대입해 버려도 된다.

공통   처리중    뒷쪽
 110       0  010111  =: a
 110       1  010111  =: b 로 가져온 값을

 110       0  010111  := a+ω_(2^7)^(010111) b
 110       1  010111  := a-ω_(2^7)^(010111) b 에 대입. 지수가 이상한 것 같다면 아래를 읽어보자.

그러면 인덱스가 완전히 바뀌지 않나? 뭔가 문제가 있을듯 한데? 특히 ω_(2^7)^(010111)에서 인덱스를 왜 똑바로 썼을까?
일단 이렇게 하면 "처리중"인 비트가 "뒷쪽" 비트 문자열의 제일 앞에 가서 붙고 있는 모양새이다.
그런데 뒤집기 전의 예시에서 우리가 111010 x (110) -> x 111010 (110) 를 했었다.
지금 보니 그 때도 제일 앞에 비트를 붙였던 것이다.

우리가 (110) 0 010111 인덱스에 써버린 값은 사실 전체 비트를 뒤집었기 때문에 처음 짠 비재귀에서 0 010111 (011) 위치에 쓰이는 값이다. 이 값에서 ω의 지수는 010111이 맞다.

최종적인 코드는 다음과 같다.
비트를 뒤집는 부분은, i를 뒤집은 결과가 j에 담기게 하고, i를 1 키울 때 비트 단위 자리올림이 되는 과정을 역순으로 똑같이 처리하는 것으로, 전명우 님 블로그(링크)의 소스를 참고했다.
아래 소스의 b *= pow(omega, down); 부분은, 어차피 down을 1씩 늘리며 처리하므로 지역변수를 두고 거기에 계속 omega를 곱해도 된다. 이러면 정확도는 떨어지고 속도는 빨라진다.

vector<C> fft_nonrec_improved(vector<int> coeff){
    vector<C> ret(coeff.begin(), coeff.end());

    int n=0;
    while((1<<n) != int(coeff.size())) ++n;

    for(int i=1, j=0; i<(1<<n); ++i){
        for(int bit=(n-1); 0<=bit; --bit){
            if(1 & (j >> bit)) j -= (1<<bit);
            else { j += (1<<bit); break; }
        }
        if(i < j) swap(ret[i], ret[j]);
    }

    for(int depth=0; depth<n; ++depth){
        int sz_down = (1<<depth);
        C omega = polar(1., 2*pi/(sz_down * 2));
        for(int up=0; up<(1<<n); up += (1<<(depth+1))){
            for(int down=0; down<sz_down; ++down){
                C a = ret[up | down];
                C b = ret[up | down | (1 << depth)];
                b *= pow(omega, down);
                ret[up | down] = a+b;
                ret[up | down | (1<<depth)] = a-b;
            }
        }
    }

    return ret;
}

보간, 그리고 Inverse Fourier transform

이제 길고도 긴 값 구하기의 과정이 끝났다. 뭘 하려고 했더라?

  • (n+m+1) 개의 서로 다른 x값을 A와 B에 대입한 결과를 각각 알기
  • (x, y)의 쌍이 (n+m+1)개 주어지면 이를 모두 지나는 (n+m)차 다항식 얻기
    이제 두 번째 과정에 대해 논한다.

이런 다항식을 구하는 문제는 보통 보간(도울 보 補 + 사이 간 間) 혹은 내삽(안 내 內 + 꽂을 삽 揷), 영어로는 interpolation이라고 부르는 문제와 관련이 있다.
예를 들어 섭씨 25도에서 물의 비열이 4.13이고, 30도에서 4.11 이라면, (단위 kJ/kg/K)
28도에서 물의 비열에 대한 관측 결과는 안 가지고 있지만 근처의 값들을 이용해서 어림짐작은 할 수 있다.
이를테면 두 점 (25, 4.13), (30, 4.11)을 지나는 직선을 그려 x=28에서의 값을 찾을 수도 있고,
점이 세 개쯤 있으면 셋 모두를 지나는 2차곡선을 그려 x=28에서의 값을 찾을 수도 있다.
이렇게 몇 개의 데이터를 바탕으로 모르는 데이터를 짐작하는 게 보간, 내삽, interpolation이다.
명칭을 정확히 말하자면 데이터 안쪽의 값을 어림하는 게 보간, 내삽, interpolation이고, 바깥쪽(위의 예시에서는 20˚C)을 어림하는 건 외삽(外揷), extrapolation이라고 부르기는 한다.

아무튼, 우리가 하고자 하는 걸 '다항식으로 보간하기'라고 바꾸어 말할 수 있다는 말이다.
그리고 다항식 곱셈이랑 관련 없는 케이스에서도 사람들은 "주어진 점들을 지나는 다항식"을 구하고 싶어했다.
답이 유일하다고 했기 때문에, 그 다항식을 Lagrange polynomial이라는 하나의 이름으로 부른다.
예를 들어 (3, 5)를 지나는 함수를 아무거나 만든다면 $f(x)=(x-3) \cdot Q(x) + 5$를 쓸 수가 있고,
어떤 g가 g(3)=5인 다항식이라면 g를 항상 $g(x)=(x-3) \cdot Q(x) + 5$ 꼴로 쓸 수 있다는 아이디어를 발전시키면 라그랑주 다항식이 된다.
그런데 일단 저걸 계산하는 과정은 $\Theta(n^2)$ 미만으로 접근하기 어려워 보인다. n개의 항 각각이 (n-1)개의 다항식의 곱으로 나타나 있다.

여기서 정말 입이 떡 벌어지는 황당한 내용이 나온다.
일단 ω의 켤레복소수를 취한다.
$$ \overline{\omega_{n}} = \textrm{e} ^ {-2 \pi i / n} = \cos \left(\frac{2 \pi}{n} \right) - i \sin \left( \frac{2 \pi}{n} \right) $$
푸리에 변환을 한 결과를 이걸 이용해 푸리에 변환하면 원래 식과 거의 같고, 크기가 n배만 된 결과가 나온다.

무슨 말인지 다시 한 번 보자.
$f(x)=x^2 + 2x + 3$이다. 이를 수열로 표현하면 $[3, 2, 1, 0]$이다. 여기에 푸리에 변환을 취하면 다음과 같다.
$$\begin{array}{l}f(\omega^0)=6+0i \\ f(\omega^1)=2+2i \\ f(\omega^2)=2+0i \\ f(\omega^3)=2-2i \end{array}$$
이제 ω의 지수 순서대로 값을 나열해 새로운 수열 $[6, 2+2i, 2, 2-2i]$를 만든다. 다항식으로 쓰면 $F(x)=(2-2i)x^3 + 2x^2 + (2+2i)x + 6$이다. 여기에 $\omega$의 켤레복소수를 가지고 푸리에 변환을 하면 다음과 같다.
$$\begin{array}{l}F(\overline{\omega^0})=12+0i \\ F(\overline{\omega^1})=8+0i \\ F(\overline{\omega^2})=4+0i \\ F(\overline{\omega^3})=0+0i \end{array}$$
여기서 $F(\overline{\omega^k})/4$의 값이 f의 k차항의 계수와 정확히 동일하다. 푸리에 역변환은 이와 같다.

사실은 이뿐만이 아니다.

  • 변환 결과의 켤레복소수에 변환을 취하면, 원래 결과의 n배가 나온다. (입력이 실수일 때만. 복소수라면 켤레복소수의 n배)
  • 변환 결과의 첫 번째를 제외한 나머지 값들을 인덱스 순으로 좌우로 뒤집은 후 변환을 취하면, 원래 결과의 n배가 나온다.
  • 변환 결과의 실수부-허수부를 서로 바꾸고 변환을 취한 후 다시 실수부-허수부를 바꾸면, 원래 결과의 n배가 나온다.

별 게 다 있다. 아무튼 inversion이 가능하며, 그 과정에서 이미 구현한 FFT를 거의 그대로 사용할 수 있다는 점이 중요하다.

역변환 증명

여기서는 $\overline{\omega_n}$을 사용하는 접근을 증명한다.

고등학교 수준의 증명, 그리고 선형대수학적인 (멋있고 빠른) 증명이 있다.

고등학교 수준의 증명을 먼저 보자. f의 각 계수의 값이 어디로 흘러가는지를 보는 것이다.
f의 i번째 계수 $c_i$는 $f(\omega^k)$에 어떤 형태로 들어가는가? $c_i \cdot \omega^{ki}$ 로 나타난다.
이제 다른 항의 값은 모두 무시하고, $c_i$에 붙는 값만 보자.
위에서 제시한 역변환을 거친 결과는, $F(x) = \sum_k=0^{n-1} f(\omega^k) \cdot x^k$라는 다항식 F(x)에 대해 ω의 켤레복소수로 푸리에 변환을 취한 값이다.
그러면 $F(\omega^j)=\sum_k=0^{n-1} f(\omega^k) \cdot \overline{\omega^{jk}}$이다.
$c_i$ 항만 남겨보면
$$\sum_k=0^{n-1} c_i \cdot \omega^{ki} \cdot \overline{\omega^{jk}} \\
= \sum_k=0^{n-1} c_i \cdot \omega^{k(i-j)} \\
= c_i \sum_k=0^{n-1} \omega^{k(i-j)}$$

그런데 $\sum_k=0^{n-1} \omega^{k(i-j)}$는 i=j일 때 모든 항이 1이 되어 n이고,
i≠j이면 각각의 지수가 모두 달라 $(1+\omega+\omega^2+\cdots+\omega^{(n-1)})$과 같은 값이다.
이 값은 실은 0이다. 이유는 $\omega^n-1=(1+\omega+\omega^2+\cdots+\omega^{(n-1)})(\omega-1)=0$인데 $\omega-1 \neq 0$이기 때문이다.
따라서 i=j인 경우 $c_i \cdot n$이 남고, i≠j인 경우 0만 나타나므로, 각각의 항이 최종 결과에 기여하는 관계를 따져보면 원래 값의 n배임을 알 수 있다.

선형대수학적 증명은 다음과 같다.
푸리에 변환을 취하는 것은, 입력을 세로로 쭉 늘어놓은 열벡터에 대해 다음 행렬을 곱하는 것과 동일하다.
$$ M_F = \begin{pmatrix} 1 & 1 & 1 & 1 & \cdots & 1 & 1 \\
1 & \omega & \omega^2 & \omega^3 & \cdots & \omega^{n-2} & \omega^{n-1} \\
1 & \omega^2 & \omega^4 & \omega^6 & \cdots & \omega^{2(n-2)} & \omega^{2(n-1)} \\
\vdots & & & & \ddots & & \\
1 & \omega^{n-2} & \omega^{(n-2)\cdot 2} & \omega^{(n-2)\cdot 3} & \cdots & \omega^{(n-2)(n-2)} & \omega^{(n-2)(n-1)} \\
1 & \omega^{n-1} & \omega^{(n-1)\cdot 2} & \omega^{(n-1)\cdot 3} & \cdots & \omega^{(n-1)(n-2)} & \omega^{(n-1)(n-1)}
\end{pmatrix} $$

먼저 이 행렬의 각 열벡터(대칭이라 각 행벡터와 동일)는 서로 수직임을 아는 것이 중요하다. $\mathbb{C}^n$의 벡터 사이에서 수직을 논할 때는 한쪽을 켤레복소수 취한 후 곱해야 함을 잊지 말자.
그러면 지수가 i인 열과 j인 열의 내적은 다음과 같다.
$$ \begin{array}{l} 1 + \overline{\omega^{i}} \omega^{j} + \overline{\omega^{2i}} \omega^{2j} + \cdots + \overline{\omega^{(n-1)i}} \omega^{(n-1)j} \\
= 1 + \omega^{j-i} + \omega^{2(j-i)} + \cdots + \omega^{(n-1)(j-i)} \\
= \begin{cases} n \quad \textrm{ if }i=j \\ 0 \quad \textrm{ otherwise} \end{cases} \end{array} $$

따라서 모든 열이 서로 수직이고, 크기는 $\sqrt{n}$이다.
실수의 경우 orthogonal matrix와 거의 똑같은 개념으로 복소수의 경우에 unitary matrix가 있다. 여기서 $(1/\sqrt{n}) M_F$는 unitary matrix이다.
orthogonal matrix의 경우 $A^{-1}=A^T$ 라는 성질이 있었듯,
unitary matrix의 경우 $U^{-1}=U^{*}=\overline{U}^T$라는 성질이 있다.
실제로 $\overline{M_F}^T \cdot M_F$의 i행 j열의 값은 방금 위에서 계산한 내적값이고, i와 j가 같을 때 n, 다를 때 0이므로
$$\overline{M_F}^T \cdot M_F = nI$$
그리고 $M_F$는 대칭이므로 $\overline{M_F}^T = \overline{M_F}$ 이고, $$M_F^{-1} = (1/n) \overline{M_F}$$라는 결론을 내릴 수 있다. 증명이 끝났다.

다항식 곱셈

이제 우리가 할 일은 명확하다.

  • A와 B에 FFT를 적용한다.
  • 적용한 결과끼리 곱한다.
  • 곱한 결과에 inverse FFT를 취해서 A와 B의 두 다항식을 곱한 결과를 얻는다.
    그런데 차수를 생각하니 뭔가 이상하다. n차 다항식에서 n개짜리 수열(y값)을 얻었고 inverse를 취했으므로, 결과도 n개의 숫자이다. n차 다항식 두 개를 곱한 결과가 n차 다항식이 된다는 점이 모순되어 보인다.

사실 우리가 얻은 결과는 circular convolution이라고 부른다. 처음에 이런 코드를 작성했었다.

for(int i = 0; i <= n; ++i)
    for(int j = 0; j <= m; ++j)
        c[i+j] += a[i] * b[j];

그런데 우리가 실제로 얻은 결과는 실은 다음과 같다.

for(int i = 0; i < n; ++i)
    for(int j = 0; j < n; ++j)
        c[(i+j)%n] += a[i] * b[j];

왜 이렇게 되는가? $C(x)=A(x)B(x)$라고 하고, C에서 n차항 계수를 0차항에 더하고, (n+1)차항 계수는 1차항에, ... 이렇게 n차 이상의 항을 모두 아래쪽으로 깎아 더한 다항식을 $C^*(x)$라고 하자.
그러면 $\omega^n=1$이기에 $C(\omega^k) = C^*(\omega^k)$ 임을 알 수 있다.
즉 $C^*(x)$는 $C^*(\omega^k)=A(\omega^k)B(\omega^k)$를 완벽하게 만족하고, n개의 점을 모두 지나는 (n-1)차 다항식이 유일하게 존재한댔으니 우리가 얻는 것은 $C^*(x)$인 것이다.

순수하게 2n개의 숫자를 모두 얻고 싶다면 방법은 간단하다. A와 B에 0을 추가해 $(2^t-1)$차로 만들었다면, 여기서 0을 더 추가해 $(2^{(t+1)}-1)$차로 만든다.

이제 다항식 곱셈을 다음과 같이 할 수 있다.

vector<C> multiply(vector<C> a, vector<C> b){
    int n = a.size() - 1;
    int m = b.size() - 1;
    int sz = 1;
    while(sz < n+1 || sz < m+1) sz *= 2;
    sz *= 2; // make sure no circular collision occurs
    a.resize(sz); b.resize(sz);

    fft(a); fft(b);
    for(int i=0; i<sz; ++i) a[i] *= b[i];
    ifft(a);
    return a;
}

소수체에서의 FFT

소수체라는 이름이 맞는지 모르겠는데, $\mathbb{Z}/p\mathbb{Z}$로 쓰기도 하고, 아무튼 원소가 정수이고, 비교와 덧셈과 곱셈이 modulo p로 정의하는 체가 있다. 이 위에서도 FFT가 가능하다.
사실 위에서 설명한 푸리에 변환과 FFT의 내용은 $\omega=e^{2\pi i / n}$ 관련된 내용만 제외하면 모든 체에 대해 맞는다.
증명을 보면, ω가 만족해야 하는 조건은 다음뿐이다.
$$\begin{array}{l} \omega^n = 1 \\ \omega^k \neq 1\textrm{ for all }k \neq 0, n, 2n, \cdots \end{array}$$

예를 들어 $p=998\:244\:353=119 \cdot 2^{23} + 1, \quad n=2^{23}$을 생각하자.
어떤 a에 대해 $(a^{119})^{n}=a^{p-1} \equiv 1\textrm{ mod }p$이다. (페르마 小정리에 의해)
그럼 모든 $a^{119}$가 ω가 될 수 있을까? 그렇지는 않다. 중간에 1이 될 수도 있기 때문이다.
해결책은 a를 p의 원시근으로 잡는 것이다.

3은 이 p의 원시근임이 알려져 있다. 즉 $3^0, 3^1, \cdots, 3^{p-1}$은 모두 서로 다르며, 이 $p$개의 숫자는 각각 $0, 1, \cdots, p-1$ 중 하나에 일대일 대응된다.
그러면 $(3^{119})^k$도 $k=2^{23}$이 되기 전까지는 1이 되지 못한다.
따라서 $n=2^{23}$일 때 $\omega=3^{119}$를 잡으면 된다.
더 작은 n에 대해서도 쓸 수 있다. n이 절반이 될 때마다 ω를 제곱시켜주면 된다.
예를 들어 $n=2^{20}$이면 $\omega=3^{119 \cdot 2 ^ 3}$으로 잡으면 된다.

'알고리즘 문제풀기 > 기타 주제' 카테고리의 다른 글

Code::Blocks에서의 편한 설정  (0) 2016.01.13
CMS Green  (0) 2015.08.20
unique한 bipartite matching의 성질  (0) 2015.07.05
불변성 원리  (1) 2015.02.14
Prefix sum  (0) 2013.07.29