Algorithm

[Algorithm] 최소 스패닝 트리(MST) 문제

mxruhxn 2024. 12. 13. 10:33
728x90
반응형

도입

  • 무향 그래프의 스패닝 트리(spanning tree): 원래 그래프의 정점 전부와 간선의 부분 집합으로 구성된 부분 그래프
    • 스패닝 트리에 포함된 간선들은 정점들을 트리 형태로 전부 연결해야 함 (사이클 X, 정점들이 꼭 부모-자식 관계일 필요 X)
  • 그래프의 스패닝 유일하지 X
  • 최소 스패닝 트리(Minimum Spanning Tree, MST) 문제: 가중치 그래프의 스패닝 트리 중 가중치의 합이 가장 작은 트리를 찾는 문제
    • = 그래프의 연결성을 그대로 유지하는 가장 '저렴한' 그래프를 찾는 문제
  • MST 문제를 푸는 2가지 유명한 알고리즘
    • 크루스칼 알고리즘
    • 프림 알고리즘
  • 두 알고리즘 모두 간선이 하나도 없는 상태에서 시작해 하나씩 트리에 간선을 추가해 가는 탐욕적 알고리즘 => 결국 같은 방법으로 증명 가능

크루스칼 알고리즘(Kruskal Algorithm)

  • 크루스칼 알고리즘: 그래프의 모든 간선을 가중치의 오름차순으로 정렬한 뒤, 사이클이 생기지 않는 선에서 스패닝 트리에 순서대로 하나씩 추가해 가는 방식
    • 사이클이 생기는 순간 더 이상 트리 형태라고 볼 수 없음

자료 구조의 선택

  • 간선의 목록을 얻어 가중치 순서대로 정렬한 뒤, 순회하며 이들을 스패닝 트리에 추가할 것 => 추가할 간선으로 인해 사이클이 발생하는지 여부 판단이 핵심
  • 어떤 간선을 추가해서 그래프에 사이클이 생기려면, 간선의 양 끝 점이 같은 컴포넌트에 속해있어야 함
    • => 두 정점이 주어졌을 때 이들이 같은 컴포넌트(집합)에 속하는지 확인 후, 아닌 경우에만 컴포넌트를 합치는 연산 수행
    • => 유니온 파인드를 통해 효율적으로 구현 가능
    • 각 상호 배타적 집합은 그래프의 한 컴포넌트를 표현

크루스칼 알고리즘의 구현

public class Kruskal {
    static int V;
    static int[] parent, rank; // for 유니온 파인드
    static ArrayList<Node>[] graph;

    public static int kruskal(List<SelectedEdge> selected) {
        int ret = 0;

        // (가중치, (정접1, 정점2))의 목록을 얻는다 : 간선의 목록 얻기
        List<Edge> edges = new ArrayList<>();
        for (int u = 0; u < V; u++) {
            for (int i = 0; i < graph[u].size(); i++) {
                int v = graph[u].get(i).x;
                int cost = graph[u].get(i).weight;

                edges.add(new Edge(cost, u, v));
            }
        }

        // 가중치 기준 오름차순 정렬
        edges.sort(Comparator.comparingInt(o -> o.weight));

        // 처음엔 모든 정점이 서로 분리
        for (int i = 0; i < edges.size(); i++) {
            // 간선 (u, v) 검사
            Edge currEdge = edges.get(i);

            // 이미 u와 v가 연결되어 있을 경우 -> 사이클이 생기므로 무시
            if (find(currEdge.u) == find(currEdge.v)) continue;

            // 아니라면 이 둘을 합친다
            union(currEdge.u, currEdge.v);
            selected.add(new SelectedEdge(currEdge.u, currEdge.v));
            ret += currEdge.weight;
        }

        return ret;
    }

    public static int find(int x) {
        if (parent[x] == x) return x;
        return parent[x] = find(parent[x]);
    }

    public static void union(int x, int y) {
        x = find(x);
        y = find(y);

        // 이미 같은 집합에 속해있는 경우 걸러내기
        if (x == y) return;

        // 항상 y의 높이가 더 크도록
        if (rank[x] > rank[y]) {
            int temp = x;
            x = y;
            y = temp;
        }

        parent[x] = y; // x를 y의 자식으로 넣기

        // 두 트리의 높이가 같은 경우 높이 1 증가
        if (rank[x] == rank[y]) ++rank[y];
    }

    static class Node {
        int x;
        int weight;

        public Node(int x, int weight) {
            this.x = x;
            this.weight = weight;
        }
    }

    static class Edge implements Comparable<Edge> {

        int weight;
        int u;
        int v;

        public Edge(int weight, int u, int v) {
            this.weight = weight;
            this.u = u;
            this.v = v;
        }

        @Override
        public int compareTo(Edge o) {
            return Integer.compare(weight, o.weight);
        }
    }

    static class SelectedEdge {
        int u;
        int v;

        public SelectedEdge(int u, int v) {
            this.u = u;
            this.v = v;
        }
    }

    public static void main(String[] args) {
        V = 7;

        parent = new int[V];
        rank = new int[V];
        graph = new ArrayList[V];

        for (int i = 0; i < V; i++) {
            parent[i] = i;
            rank[i] = 1;
            graph[i] = new ArrayList<>();
        }

        addEdge(0, 2, 1);
        addEdge(0, 1, 5);
        addEdge(1, 3, 1);
        addEdge(1, 5, 3);
        addEdge(1, 6, 3);
        addEdge(2, 3, 4);
        addEdge(3, 4, 5);
        addEdge(3, 5, 3);
        addEdge(5, 6, 2);

        List<SelectedEdge> selected = new ArrayList<>();
        int cost = kruskal(selected);

        System.out.println(cost);

        for (SelectedEdge selectedEdge : selected) {
            System.out.println(selectedEdge.u + " " + selectedEdge.v);
        }
    }

    private static void addEdge(int u, int v, int weight) {
        graph[u].add(new Node(v, weight));
        graph[v].add(new Node(u, weight));
    }
}
  • SelectedEdge 클래스는 단순히 선택된 정점 집합을 얻기 위한 클래스이므로, 필요에 따라 사용할 것
  • 간선 목록 얻기 -> 정렬하기(Comparable 구현 필요) -> 순서대로 뽑아서 검사 & 추가
    • 추가 시 사이클 발생 여부 판단을 위해 유니온 파인드 자료구조 사용

정당성 증명

  • 크루스칼 알고리즘은 각 간선을 그래프에 추가할 때 뒤에 오는 간선들에 대한 고려 X => 탐욕적 알고리즘
    • => 탐욕적 알고리즘과 같은 형태로 정당성 증명 가능
  • 탐욕적 선택 속성 증명
    • 어떤 간선을 트리에 추가하기로 결정했을 때, 이로 인해 최소 스패닝 트리를 찾을 수 없게 되는 일이 없음을 증명
    • 증명 과정
      • 크루스칼 알고리즘이 선택하는 간선 중 최소 스패닝 트리 T에 포함되지 않는 간선이 있다고 가정
      • 이 중 첫 번째로 선택되는 간선을 (u, v)라고 하자. => T는 (u, v) 포함 X
      • u와 v는 T 상에서 다른 경로로 연결되어 있을 것 => 이 경로를 이루는 간선 중 하나는 반드시 (u, v)와 가중치가 같거나 커야 함
        • 작았다면 크루스칼 알고리즘에 의해 이미 이 간선들을 모두 선택해서 (u, v)를 연결해버렸을 것 -> 애초에 (u, v)가 선택될 일도 없었을 것
      • => 이 경로 상에서 (u, v) 이상의 가중치를 갖는 간선 하나를 골라 T에서 지워 버리고 (u, v)를 추가하더라도 여전히 스패닝 트리일 것 (가중치가 줄면 줄었지 늘지는 않음)
      • => 이 속성은 마지막 간선 추가해서 스패닝 트리가 완성될 때까지 성립하므로 마지막에 얻은 트리는 항상 최소 스패닝 트리일 것
  • 최적 부분 조건 증명 => 항상 최소 간선만 선택하면 최소 스패닝 트리를 구할 수 있을 것

프림 알고리즘(Prim Algorithm)

  • 프림 알고리즘: 하나의 시작점에서 구성된 트리에 최소 가중치를 갖는 간선을 하나씩 추가하며 스패닝 트리가 될 때까지 키워나가는 방식
    • 선택된 간선들은 항상 중간 과정에서도 연결된 트리를 이루게 됨
    • 이미 만들어진 트리에 인접한 간선만을 고려한다는 점을 제외하면 원리 자체는 크루스칼 알고리즘과 동일
    • 스패닝 트리를 찾아낼 때까지 후보 간선들 중 가중치가 가장 작은 간선을 추가하는 과정을 반복
    • 후보 간선: 아직 방문하지 않은 정점을 잇는 간선 중 최소 가중치를 갖는 간선

프림 알고리즘의 구현

public class Prim {
    static int V;
    static final int INF = 987654321;
    static ArrayList<Node>[] graph;

    static int prim(List<SelectedEdge> selected) {
        // 해당 정점이 트리에 포함되어 있는지 여부를 담는 배열
        boolean[] added = new boolean[V];
        // 트리에 인접한 간선 중 해당 정점에 닿는 최소 가중치 정보를 담는 배열
        int[] minWeights = new int[V];
        // 각 정점이 트리와 연결되었는지 여부를 확인하기 위해, 사용하는 간선의 다른 한쪽 끝 정점을 담는 배열
        // ex) parent[0] = 1 -> 0번 정점이 (0, 1) 간선을 통해 트리에 연결됨
        int[] parent = new int[V];

        for (int i = 0; i < V; i++) {
            minWeights[i] = INF;
            parent[i] = -1;
        }

        // 가중치의 합
        int ret = 0;

        // 0번 정점을 시작점으로 항상 트리에 가장 먼저 추가
        minWeights[0] = parent[0] = 0;
        for (int iter = 0; iter < V; iter++) {
            // 다음에 트리에 추가할 정점을 찾음
            // 알고있는 최소 간선에서 가장 작은 값 선택
            int u = -1;
            for (int v = 0; v < V; v++) {
                if (!added[v] && (u == -1 || minWeights[u] > minWeights[v])) {
                    u = v;
                }
            }

            // (parent[u], u)를 트리에 추가
            if (parent[u] != u) selected.add(new SelectedEdge(parent[u], u));

            ret += minWeights[u];
            added[u] = true;

            // u에 인접한 간선 (u, v)들을 검사
            for (int i = 0; i < graph[u].size(); i++) {
                Node nextNode = graph[u].get(i);
                int v = nextNode.x;
                int weight = nextNode.weight;
                // 인접한 정점이 아직 추가되지 않았고, 가중치가 지금까지 찾은 해당 정점까지의 최소 간선보다 작다면 정보 업데이트
                if (!added[v] && minWeights[v] > weight) {
                    parent[v] = u;
                    minWeights[v] = weight;
                }
            }
        }

        return ret;
    }


    static class Node {
        int x;
        int weight;

        public Node(int x, int weight) {
            this.x = x;
            this.weight = weight;
        }
    }

    static class SelectedEdge {
        int u;
        int v;

        public SelectedEdge(int u, int v) {
            this.u = u;
            this.v = v;
        }
    }

    public static void main(String[] args) {
        V = 7;

        graph = new ArrayList[V];

        for (int i = 0; i < V; i++) {
            graph[i] = new ArrayList<>();
        }

        addEdge(0, 2, 1);
        addEdge(0, 1, 5);
        addEdge(1, 3, 1);
        addEdge(1, 5, 3);
        addEdge(1, 6, 3);
        addEdge(2, 3, 4);
        addEdge(3, 4, 5);
        addEdge(3, 5, 3);
        addEdge(5, 6, 2);

        List<SelectedEdge> selected = new ArrayList<>();
        int cost = prim(selected);

        System.out.println(cost);

        for (SelectedEdge selectedEdge : selected) {
            System.out.println(selectedEdge.u + " " + selectedEdge.v);
        }
    }

    private static void addEdge(int u, int v, int weight) {
        graph[u].add(new Node(v, weight));
        graph[v].add(new Node(u, weight));
    }
}
  • added[]: 각 정점이 트리에 포함되어 있는지 여부를 담는 배열
  • minWeights[]: 트리에 인접한 간선 중 해당 정점에 닿는 최소 가중치 정보를 담는 배열
    • ex) minWeights[i] = INF -> 현재 트리에서는 아직 i번 정점에 닿을 수 없음
    • ex) minWeights[i] = 8 -> 현재 트리에서 i번 정점에 닿는 최소 가중치는 8
  • parent[]: 각 정점이 트리와 연결되었는지 여부를 확인하기 위해, 사용하는 간선의 다른 한쪽 끝 정점을 담는 배열
    • ex) parent[0] = 1 -> 0번 정점이 (0, 1) 간선을 통해 트리에 연결됨
  • 한 정점에 닿는 간선이 두 개 이상일 경우 이들을 하나하나 검사하는 것은 시간 낭비
    • => 트리에 속하지 않은 각 정점에 대해, 트리와 이 정점을 연결하는 가장 짧은 간선에 대한 정보(minWeights)만 저장하고, 각 정점을 순회하면서 다음에 추가할 정점을 찾는 식(minWeights[v]가 가장 작은 v)으로 구현
  • 전체 시간 복잡도는 O(|V|^2 + |E|)
    • 일반적인 경우 크루스칼 알고리즘보다 느림
    • but, 밀집 그래프의 경우 |E|는 거의 |V|^2이므로, 시간 복잡도는 O(|V|^2) => 밀집 그래프의 경우 프림 알고리즘이 크루스칼 알고리즘보다 빠르게 동작
  • 전체 과정
    1. 먼저 minWeights를 INF, parent를 -1로 초기화
      • 처음엔 시작점(0)을 제외하고는 아무 정보도 없는 상태 => 트리에 정점이 추가될 때마다(한 iter마다) 쟁신하여 사용
      • => 다익스트라 알고리즘과 유사
    2. 시작점을 첫 iter에 추가하기 위해 minWeights[0] = parent[0] = 0
    3. 다음 동작은 정점의 개수(V)만큼 반복
      1. 다음 추가할 정점 선택 => 위 코드에서 iter 내의 지역변수 u가 이에 해당
      2. 선택된 정점을 트리에 추가
        • selected 리스트에 add
        • ret += minWeights[u]
        • added[u] = true
      3. 추가된 정점과 인접한 간선을 검사하여, minWeights[v]parent[v] 정보 업데이트
    4. 이후 가중치의 합(ret) 반환

프림 알고리즘의 다른 구현

  • 위 코드는 우선순위 큐를 사용하지 않고 다익스트라 알고리즘을 구현한 코드와 매우 유사한 형태
    • 각 정점에 대해 지금까지 알려진 최단 거리를 저장하지 않고, 마지막 간선의 가중치를 저장하고 있다는 점만 다름
  • => 다익스트라 알고리즘에서처럼 우선순위 큐를 사용하여 구현 가능
    • => O(|E|lg|V|)에 프림 알고리즘 구현 가능
    • 우선순위 큐는 각 정점의 번호를 minWeight[]이 증가하는 순으로 정렬해 담도록 함

정당성 증명

  • 프림 알고리즘은 매 반복마다 지금 만들어진 트리에 인접한 간선들만 고려한다는 점을 제외하면 크루스칼 알고리즘과 동일
  • => 프림 알고리즘의 증명 또한 크루스칼 알고리즘과 동일

알고스팟 예제 풀어보기: 근거리 네트워크(문제 ID: LAN, 난이도: 하)

알고스팟 - LAN

public class Main {

    static final int INF = 987654321;
    static int N, M;
    static double[][] graph;
    static Point[] points;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));

        int C = Integer.parseInt(br.readLine());
        while (C-- > 0) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            N = Integer.parseInt(st.nextToken()); // 건물 수
            M = Integer.parseInt(st.nextToken()); // 케이블 수

            points = new Point[N];
            graph = new double[N][N];

            st = new StringTokenizer(br.readLine());
            for (int i = 0; i < N; i++) { // 건물 x 좌표
                int x = Integer.parseInt(st.nextToken());
                points[i] = new Point(x);
            }

            st = new StringTokenizer(br.readLine());
            for (int i = 0; i < N; i++) { // 건물 y 좌표
                int y = Integer.parseInt(st.nextToken());
                points[i].setY(y);
            }

            for (int u = 0; u < N; u++) {
                for (int v = 0; v < N; v++) {
                    double dist = calcDist(points[u], points[v]);
                    graph[u][v] = graph[v][u] = dist;
                }
            }

            for (int i = 0; i < M; i++) { // 케이블 연결 정보
                // 한 쌍의 건물을 연결하는 케이블이 두 개 이상 있을 수도 있음
                st = new StringTokenizer(br.readLine());
                int u = Integer.parseInt(st.nextToken());
                int v = Integer.parseInt(st.nextToken());
                graph[u][v] = graph[v][u] = 0;
            }

            double ret = prim();
            bw.append(String.format("%.7f\n", ret));
        }

        bw.flush();
        bw.close();
        br.close();
    }

    static double prim() {
        boolean[] added = new boolean[N];
        double[] minWeights = new double[N];
        int[] parent = new int[N];

        for (int i = 0; i < N; i++) {
            minWeights[i] = INF;
            parent[i] = -1;
        }

        // 가중치의 합
        double ret = 0;

        minWeights[0] = parent[0] = 0;
        for (int iter = 0; iter < N; iter++) {
            int u = -1;
            for (int v = 0; v < N; v++) {
                if (!added[v] && (u == -1 || minWeights[u] > minWeights[v])) {
                    u = v;
                }
            }

            ret += minWeights[u];
            added[u] = true;

            for (int v = 0; v < graph[u].length; v++) {
                double weight = graph[u][v];
                if (!added[v] && minWeights[v] > weight) {
                    parent[v] = u;
                    minWeights[v] = weight;
                }
            }
        }

        return ret;
    }

    static class Node {
        int x;
        double weight;

        public Node(int x, double weight) {
            this.x = x;
            this.weight = weight;
        }
    }

    static class Edge implements Comparable<Edge> {
        double weight;
        int u;
        int v;

        public Edge(double weight, int u, int v) {
            this.weight = weight;
            this.u = u;
            this.v = v;
        }

        @Override
        public int compareTo(Edge o) {
            return Double.compare(weight, o.weight);
        }
    }

    static class Point {
        int x, y;

        public Point(int x) {
            this.x = x;
        }

        public void setY(int y) {
            this.y = y;
        }
    }

    private static double calcDist(Point a, Point b) {
        return Math.sqrt(Math.pow(b.x - a.x, 2) + Math.pow(b.y - a.y, 2));
    }
}
  • 처음 주어진 그래프에서 각 컴포넌트들을 정점 하나로 압축한 뒤, 모든 간선을 연결하여 이에 대해 최소 스패닝 문제 알고리즘을 적용하는 방식
  • => 각 컴포넌트를 하나의 정점으로 압축하는 효과를 얻기 위해 이미 연결된 건물의 가중치를 0으로 둠
  • 알고리즘을 수행할 대상 그래프가 밀집 그래프이므로 프림이 더 효율 좋을 것이라고 생각되어 프림 알고리즘 사용
  • 각 정점과 인접한 정점을 빠르게 조회하기 위해 인접 리스트가 아닌 인접 행렬 사용
728x90
반응형