본문 바로가기
백준/트리

백준 6497 c++ "전력난" -PlusUltraCode-

by PlusUltraCode 2025. 10. 1.

https://www.acmicpc.net/problem/6497

[필자 사고]

MST 항상 미뤄왔떤 알고리즘이지만 막상 풀어보니 그렇게 어려운 알고리즘이 아니였다.

개념은 아래와 같다.

모든 노드들을 단일 선으로 연결한다 이게 스패닝 트리 이고 스패닝 트리는 여러개가 나올 수 있다. 

그 중 가중치의 합이 가장 작은걸 mst라고 한다.

MST를 만들때는 UNionFind를 이용하여 해당 노드가 같은 부모가 되었는지를 검사하는 형태로 이루어 진다.

 

아래는 자세한 코드 해설이다.

[코드 해설]

1. struct Node

  • 한 간선을 표현하는 구조체이다.
  • start, end, weight 멤버 변수를 가지고 있으며, 각각 간선의 시작점, 끝점, 가중치를 의미한다.
  • operator<를 오버로딩해서 sort 함수가 간선을 가중치 기준 오름차순으로 정렬할 수 있게 한다.

2. int find(int a)

  • Union-Find 자료구조의 "find 연산"이다.
  • 노드 a가 속한 집합의 루트(parent)를 찾는다.
  • 경로 압축(Path Compression)을 사용하여 재귀적으로 최상위 부모를 찾아갔다가, 찾은 부모를 현재 노드에 바로 저장한다.
  • 시간 복잡도를 거의 상수 시간으로 줄여주는 역할을 한다.

3. bool Union(int a, int b)

  • Union-Find 자료구조의 "union 연산"이다.
  • a와 b가 속한 집합을 합친다.
  • 두 노드의 루트를 각각 찾은 후, 루트가 같으면 같은 집합이므로 false를 반환한다.
  • 루트가 다르면 한쪽 루트를 다른 쪽에 연결하여 집합을 합치고, 이때 true를 반환한다.
  • MST 알고리즘에서 사이클을 방지하는 핵심 역할을 한다.

4. int main()

  • 프로그램의 시작점이다.
  • 무한 루프를 돌며 여러 테스트 케이스를 처리한다.
    1. N과 M을 입력받는다. 만약 (0,0)이면 프로그램을 종료한다.
    2. 간선 정보를 저장할 arr 벡터를 초기화하고, Union-Find를 위한 parent 배열을 준비한다. 초기에는 각 노드가 자기 자신을 부모로 갖는다.
    3. M개의 간선을 입력받아 arr에 저장하고, 모든 간선 가중치의 합을 total에 더한다.
    4. arr를 간선 가중치 기준으로 오름차순 정렬한다.
    5. Kruskal 알고리즘을 실행한다. 정렬된 간선을 하나씩 확인하면서 두 정점이 같은 집합에 속해 있지 않으면 Union을 하고, 그 간선의 가중치를 mst에 더한다.
    6. 최종적으로 total - mst를 출력한다. 이는 전체 간선 비용에서 MST 비용을 뺀 값으로, 문제에서 요구하는 절약할 수 있는 최대 비용이다.

[소스 코드]

#include <bits/stdc++.h>
using namespace std;

struct Node {
    int start, end, weight;
    bool operator<(const Node& other) const {
        return weight < other.weight;
    }
};

int N, M;
vector<Node> arr;
vector<int> parent;

int find(int a) {
    if (parent[a] == a) return a;
    return parent[a] = find(parent[a]);
}

bool Union(int a, int b) {
    a = find(a);
    b = find(b);
    if (a == b) return false;
    parent[b] = a;
    return true;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    while (true) {
        cin >> N >> M;
        if (N == 0 && M == 0) break;

        arr.clear();
        parent.assign(N, 0);
        for (int i = 0; i < N; i++) parent[i] = i;

        long long total = 0;
        for (int i = 0; i < M; i++) {
            int s, e, w;
            cin >> s >> e >> w;
            arr.push_back({ s, e, w });
            total += w;
        }

        sort(arr.begin(), arr.end());

        long long mst = 0;
        for (auto& edge : arr) {
            if (Union(edge.start, edge.end)) {
                mst += edge.weight;
            }
        }

        cout << total - mst << "\n";
    }
    return 0;
}