Computer Science/알고리즘

[알고리즘] 분할 정복 - 쉬트라센의 행렬 곱셈 (Divide and Conquer - Strassen's Matrix Multiplication)

바보1 2022. 4. 5. 00:07

1. 행렬 곱셈에 대한 간단한 소개

 

예전의 글에서 구한 행렬 곱셈의 시간 복잡도는 \(T(n)\,=\,n^3\)이었습니다.

 

왜냐? 두 개의 행렬에 대해 A의 행, B의 열을 곱해야하기 때문입니다. A의 n개의 행에 대해서 B는 \(n^2\)번을 곱해야하기 때문입니다.

 

하지만 Strassen씨는 시간 복잡도가 \(n^3\)보다 더 빠른 알고리즘을 개발했습니다.


2. Strassen's Algorithm에 대한 소개

 

행렬 A)

\(\begin{bmatrix}a_{11}&a_{12}  \\a_{21}&a_{22}  \\\end{bmatrix}\)

 

행렬 B)

\(\begin{bmatrix}b_{11}&b_{12}  \\b_{21}&b_{22}  \\\end{bmatrix}\)

 

행렬 C)

\(\begin{bmatrix}c_{11}&c_{12}\\ c_{21}&c_{22}\\ \end{bmatrix}\) = \(\begin{bmatrix}a_{11}&a_{12}  \\a_{21}&a_{22}  \\\end{bmatrix}\)x\(\begin{bmatrix}b_{11}&b_{12}  \\b_{21}&b_{22}  \\\end{bmatrix}\)

이고,

 

\(m_{1}\,=\,(a_{11}+a_{22})(b_{11}+b_{22})\)

\(m_{2}\,=\,(a_{21}+a_{22})b_{11}\)

\(m_{3}\,=\,a_{11}(b_{12}-b_{22})\)

\(m_{4}\,=\,a_{22}(b_{21}-b_{1})\)

\(m_{5}\,=\,(a_{11}+a_{12})b_{22}\)

\(m_{6}\,=\,(a_{21}-a_{11})(b_{11}+b_{12})\)

\(m_{7}\,=\,(a_{12}-a_{22})(b_{21}+b_{22})\)

일 때, 

 

행렬 C는

\(\begin{bmatrix}m_{1}+m_{4}-m_{5}+m_{7}&m_{3}+m_{5}  \\ m_{2}+m_{4}&m_{1}+m_{3}-m_{2}+m_{6}  \\ \end{bmatrix}\)

입니다.

 

참 신기하죠? 저도 혹시 몰라서 직접 다 계산해봤는데 맞더라구요..

 

암튼 대충 감이 잡히시나요?

 

 

암튼 행렬이 2의 거듭 제곱일 때, 정확히 반씩 4등분 한 다음에 각각 Strassen's의 알고리즘을 적용해서 정복하면 되겠습니다.


3. 알고리즘 및 코드

 

problem : 행렬 A와 행렬 B를 곱한 행렬 C

Input : 양의 정수 n, Threshold, n*n 행렬 A, n*n행렬 B (이떄 A와 B가 곱할 수 없는 경우는 제외)

Output : 행렬 A와 행렬 B가 곱해진 행렬 C

 

알고리즘)

1. 행렬 A와 B를 4등분한다.

2. 4등분한 A의 subMatrix, B의 subMatrix를 Strassen's의 공식에 맞게 대입한다.

3. 이때, subMatrix도 곱해야하는데, Strassen's 함수를 재귀 호출한다.

4. 인자로 들어온 행렬의 변의 길이가 threshold 이하일 때, Strassen's의 알고리즘을 쓰지말고, 단순한 행렬 곱셈을 실시한다.

 

주의점)

1. n이 2의 거듭제곱이 아니라면 2의 거듭제곱이 되게 나머지 부분에 0을 채워준다.

 

코드, Python)

import sys
import copy
input = sys.stdin.readline

def check_power_of_2(N):
    if (N & -N) == N:
        return N
    else:
        k = 1
        while k < N:
            k *= 2
        return k


def add_zero(M, idx, N):
    for i in range(N):
        M[i].extend([0 for _ in range(idx)])
    M.extend([0 for _ in range(N + idx)] for _ in range(idx))

    return M


def print_matrix(N, M):
    for i in range(N):
        print(*M[i][:N])


def partition(m, M, M11, M12, M21, M22):
    for i in range(m):
        for j in range(m):
            M11[i][j] = M[i][j]
            M12[i][j] = M[i][j + m]
            M21[i][j] = M[i+m][j]
            M22[i][j] = M[i+m][j+m]


def mmult(N, A, B, C):
    for i in range(N):
        for j in range(N):
            for k in range(N):
                C[i][k] += A[i][j] * B[j][k]


def madd(N, A, B, C):
    # print(A, B ,N)
    for i in range(N):
        for j in range(N):
            C[i][j] = A[i][j] + B[i][j]


def msub(N, A, B, C):
    for i in range(N):
        for j in range(N):
            C[i][j] = A[i][j] - B[i][j]


def resize(m, M):
    M.extend([[0 for _ in range(m)] for _ in range(m)])


def combine(m, C, C11, C12, C21, C22):
    for i in range(m):
        for j in range(m):
            C[i][j] = C11[i][j]
            C[i][j+m] = C12[i][j]
            C[i+m][j] = C21[i][j]
            C[i+m][j+m] = C22[i][j]


def strassen(N, A, B, C):
    global Threshold
    global count
    count += 1
    if N <= Threshold:
        mmult(N, A, B, C)
    else:
        m = N//2

        main_dict = {f'{c}{i}{j}': [] for c in ['A', 'B', 'C'] for i in range(1, 3) for j in range(1, 3)}
        main_dict.update({f'M{i}': [] for i in range(1, 8)})
        main_dict.update({f'{c}': [] for c in ['L', 'R']})

        for key in list(main_dict.keys()):
            resize(m, main_dict[key])

        # print(main_dict)

        partition(m, A, *[main_dict[f'{c}'] for c in list(main_dict.keys()) if 'A' in c])
        partition(m, B, *[main_dict[f'{c}'] for c in list(main_dict.keys()) if 'B' in c])

        # print(A21, A22, m)

        A11, A12, A21, A22 = [main_dict[f'{c}'] for c in list(main_dict.keys()) if 'A' in c]
        B11, B12, B21, B22 = [main_dict[f'{c}'] for c in list(main_dict.keys()) if 'B' in c]
        C11, C12 ,C21, C22 = [main_dict[f'{c}'] for c in list(main_dict.keys()) if 'C' in c]
        M1, M2, M3, M4, M5, M6, M7 = [main_dict[f'{c}'] for c in list(main_dict.keys()) if 'M' in c]
        L, R = main_dict['L'], main_dict['R']

        # for c in ['A', 'B', 'C']:
        #     for i in range(1, 3):
        #         for j in range(1, 3):
        #             locals()[f'{c}{i}{j}'] = main_dict[f'{c}{i}{j}']
        #
        # for i in range(1, 8):
        #     locals()[f'M{i}'] = main_dict[f'M{i}']
        #
        # for c in 'L', 'R':
        #     locals()[f'{c}'] = main_dict[f'{c}']

        madd(m, A11, A22, L)
        madd(m, B11, B22, R)
        strassen(m, L, R, M1)       ###M1

        madd(m, main_dict['A21'], main_dict['A22'], L)
        strassen(m, L, B11, M2)     ###M2

        msub(m, B12, B22, R)
        strassen(m, A11, R, M3)     ###M3

        msub(m, B21, B11, R)
        strassen(m, A22, R, M4)     ###M4

        madd(m, A11, A12, L)
        strassen(m, L, B22, M5)     ###M5

        msub(m, A21, A11, L)
        madd(m, B11, B12, R)
        strassen(m, L, R, M6)       ###M6

        msub(m, A12, A22, L)
        madd(m, B21, B22, R)
        strassen(m, L, R, M7)       ###M7

        # print(M1, M2, M3, M4, M5, M6, M7)

        madd(m, M1, M4, L)
        msub(m, L, M5, L)
        madd(m, L, M7, C11)          ###C1

        madd(m, M3, M5, C12)        ###C2

        madd(m, M2, M4, C21)        ###C3

        madd(m, M1, M3, L)
        msub(m, L, M2, L)
        madd(m, L, M6, C22)          ###C4

        combine(m, C, C11, C12, C21, C22)



N, Threshold = map(int, input().split())

k = check_power_of_2(N)
idx = k - N

A = [list(map(int, input().split())) for _ in range(N)]
B = [list(map(int, input().split())) for _ in range(N)]

if idx != 0:
    A = add_zero(A, idx, N)
    B = add_zero(B, idx, N)

C = [[0 for _ in range(k)] for _ in range(k)]

count = 0

strassen(k, A, B, C)

print(count)
print_matrix(N, C)

생각보다 길죠?

 

제가 다음부터는 주석 처리를 하겠습니다.

작은 함수들은 이해하기 쉬우실겁니다. 근데 Strassen's 함수는 아마 이해하시기 어려울겁니다.

왜냐면 전역 변수의 남발, main_dict의 남발 등등으로 인해서,,

나중에 시간이 된다면 좀 더 최적화해서 올리겠습니다.


4. 시간 복잡도 분석

 

bisic operation : 곱셈 연산

 

기존의 행렬 곱셈의 시간 복잡도는 \(n^3 - n^2\)이고, \(\Theta(n^3)\)이었습니다.

 

하지만 쉬트라센의 알고리즘을 보시면, 총 곱셈이 7번 하는 걸 볼 수 있습니다.

 

또한 덧셈/뺄셈 연산을 총 18번 합니다.

 

즉 쉬트라센의 시간 복잡도는 \(T(n)\,=\,7T\left ( \frac{n}{2} \right )+18\left ( \frac{n}{2} \right )^{2}\)이라고 볼 수 있습니다.

 

이는 곧 \(6n^{lg 7}-6n^2\,\epsilon\,\theta(n^{2.81})\)임을 의미합니다.

 

생각보다 개선이 안 됐다고 생각하실 수 있는데, n이 무한히 커지면 시간 차이는 아득합니다 ㅎㅎ

 

현재 행렬 곱셈은 최소 \(\Theta(n^{2.38})\)까지 나와잇습니다.

그리고 행렬 곱셈의 최소 시간 복잡도는 \(\Omega(n^2) \)입니다.

이 이하로 떨어질 수가 없습니다. 왜냐면 n * n 이기 때문에..

 

아무튼 어려운 알고리즘이므로 잘 공부하셨으면 좋겠습니다.

 

감사합니다.

 

 

 

지적 환영합니다.