Computer Science/알고리즘

[알고리즘] 동적 계획법 - 연쇄 행렬 곱셈 코드 (Chained Matrix Multiplication Code)

바보1 2022. 4. 15. 15:44

문제)

Description

교재와 강의자료를 참고하여, Algorithm 3.6/3.7 연쇄 행렬 곱셈 알고리즘의 구현을 완성하라.


행렬의 개수 n과 각 행렬의 크기 값의 배열 d를 입력으로 받고

M, P 행렬의 값을 구해서 출력하고,

단위 곱셈의 최적 횟수 및 괄호로 묶은 행렬 곱셈의 순서를 출력하라.


단, 최적 횟수의 최대값은 999999를 넘지 않는다.

Input

첫 번째 줄에 행렬의 개수 n이 주어진다.

두 번째 줄부터 행렬의 크기 값의 배열 d가 주어진다.

Output

먼저 행렬 M의 윗 부분 삼각형을 출력한다. (0을 포함)

다음으로 행렬 P의 윗 부분 삼각형을 출력한다. (0을 포함)

M과 P를 출력한 후에 최적값을 출력한다.

다음 줄에 행렬 곱셈의 순서를 괄호로 묶어 출력한다.

모든 단위 행렬에도 괄호가 포함되어야 하고,

행렬 이름은 A1, A2, .... , An 으로 표기한다.

Sample Input 1

6
5 2 3 4 6 7 8

Sample Output 1

0 30 64 132 226 348
0 24 72 156 268
0 72 198 366
0 168 392
0 336
0
0 1 1 1 1 1
0 2 3 4 5
0 3 4 5
0 4 5
0 5
0
348
((A1)(((((A2)(A3))(A4))(A5))(A6)))

Sample Input 2

1
3 5

Sample Output 2

0
0
0
(A1)

코드)

import sys
import copy
input = sys.stdin.readline


def minmult(m, d, p, i, j):
    # i부터 j까지의 행렬 중 최소 곱셈 횟수를 구해줌
    minimum = 999999        # 최댓값은 999999를 넘지 않으므로
    for k in range(i, j):
        value = m[i][k] + m[k+1][j] + d[i] * d[k + 1] * d[j + 1]        # k를 기준으로 나눈 두 행렬을 곱함
        if minimum > value:
            minimum = value
            p[i][j] = k

    return minimum


def func(n ,d, m, p):
    # 행렬의 최소 곱셈 횟수를 구하는 함수
    # 행렬은 두 개씩만 곱했을 때부터 시작해서, 최종적으로 1부터 n까지의 곱셈을 한다.
    for diagonal in range(n-1, 0, -1):        # diagonal은 곱하는 갯수
        for i in range(1, diagonal + 1):        # diagonal과 같음
            j = i + n - diagonal        # 만약 5개를 곱하면 +1, 4개를 곱하면 +2를 하면 됨
            # 쉽게 말해서 개수가 5개면 (1,2) (2,3) .. (5,6)을 하면 되고, 개수가 1개면 (1,6)을 하면 됨
            m[i][j] = minmult(m, d, p, i, j)


def path(p, start, end):
    # 곱셈 순서를 알려주는 함수
    if start == end:
        print(f'(A{start})', end='')        # 두 개가 같다면 출력하면 됨
    elif start < end:
        k = p[start][end]
        print('(', end='')
        path(p, start, k)
        path(p, k + 1, end)
        print(')', end='')


n = int(input())        # 행렬의 개수
d = [0] + list(map(int, input().split()))        # 각 행렬의 크기 리스트
m = [[0] * (n + 1) for _ in range(n + 1)]     # 행렬의 최소 곱셈을 저장하는 배열, 기준을 1부터 시작하기 위해서 1씩 더함
p = copy.deepcopy(m)        # p는 어떤 행렬을 기준으로 나눴는지 알려줌

func(n, d, m, p)

for i in range(1, n + 1):
    print(*m[i][i:n+1], sep=' ')
for i in range(1, n + 1):
    print(*p[i][i:n+1], sep=' ')

print(m[1][n])

path(p, 1, n)