세그먼트 트리(Segment Tree)
트리 자료구조에는 여러 가지가 있지만 그 중, 저장된 자료들을 적절히 전처리하여 그들에 대한 질의들을 빠르게 대답할 수 있도록 하는 트리를 '세그먼트 트리(Segment Tree)'라고 한다.
세그먼트 트리는 흔히 일차원 배열의 특정 구간에 대한 질문을 빠르게 대답하는 데 사용한다. 가장 간단하면서도 흔히 사용되는 예로 구간의 최소치를 구하는 문제인 '구간 최소 쿼리(RMQ, Range Minimum Query)'가 있다. 어떤 배열 A의 부분 구간의 최소치를 구하는 연산을 '여러 번'하고 싶다고 하자. 예를 들어 A = {1, 2, 1, 2, 3, 1, 2, 3, 4}라면 [2, 4] 구간의 최소치는 1이고, [6, 8] 구간의 최소치는 2이다. 이 연산은 구간이 주어질 때마다 해당 배열을 순회하며 최소치를 찾는 식으로 O(n) 시간에 해결할 수 있지만, 세그먼트 트리를 이용하면 O(lgN) 시간 안에 이를 해결할 수 있다.
세그먼트 트리의 기본 아이디어
- 세그먼트 트리의 기본저긴 아이디어는 주어진 배열의 구간(세그먼트)들을 표현하는 '이진 트리'를 만드는 것
- 세그먼트 트리의 루트는 항상 배열의 전체 구간 [0, n-1]을 표현
- 한 트리의 왼쪽 자식과 오른쪽 자식은 각각 해당 구간의 왼쪽 반과 오른쪽 반을 표현한다
- 길이가 1인 구간을 표현하는 노드들은 세그먼트 트리의 리프 노드가 될 것이다.
위 그림은 배열의 길이가 15인 세그먼트 트리의 각 노드가 표현하는 구간들을 보여준다. 세그먼트 트리는 노드마다 해당 구간에 대한 계산 결과(전처리 결과)를 저장해둔다. 에를 들면, RMQ를 풀기위한 구간의 최소치를 구하는 세그먼트 트리는 해당 구간의 최소치를 각 노드에 저장한다.
이러한 전처리 과정을 수행해두면 어떤 구간이 주어지더라도 이 구간을 세그먼트 트리의 노드에 포함된 구간들의 합집합으로 표현할 수 있다. 예를 들어, 길이가 15인 배열에서 [6, 12] 구간은 위 그림에서 빨갛게 칠해진 세 구간 [6, 7]과 [8, 11], [12, 12]의 합집합으로 표현할 수있다. 이때 각 구간의 최소치는 구간마다 미리 계산(전처리)해두었으니 이 셋 중의 최소치가 우리가 원하는 답이 된다.
어떤 구간이 주어지건 간에 답을 찾기 위해 우리가 고려해야 하는 구간의 수는 O(lgN)이 되고, 그에 따라 수행 시간 역시 O(lgN)이 되는 것이다.
세그먼트 트리의 표현
구간 트리의 설명을 위해 위에서 언급한 RMQ 문제를 풀어보자!
세그먼트 트리는 포화 이진 트리에 가깝다. 이러한 포화 이진 트리는 포인터로 구현하기보다는 배열을 통해 구현하는 것이 메모리를 더 절약할 수 있고 구현도 간단하다. 루트 노드를 1번 원소로, 노드 i의 왼쪽 자손과 오른쪽 자손을 각각 2*i
와 2*i+1
번 원소로 표현하도록 하자. 그리고 배열의 길이는 n에 4를 곱한만큼으로 설정하자.
세그먼트 트리의 초기화
배열 int[] array
가 주어질 때, 각 노드마다 해당 구간의 최소치를 구하는 계산하는 함수 init()
을 구현한 코드는 다음과 같다. init()은 현재 구간을 두 개로 나눠 재귀 호출한 뒤, 두 구간의 최소치 중 더 작은 값을 선택해 해당 구간의 최소치를 계산한다.
public class RMQ {
private int n; // 배열의 길이
private int[] rangeMin;
public RMQ(int[] array) {
n = array.length;
rangeMin = new int[n * 4];
init(array, 0, n - 1, 1);
}
// node 노드가 array[left..right] 배열을 표현할 때
// node를 루트로 하는 서브트리를 초기화하고, 이 구간의 최소치를 반환한다.
private int init(int[] array, int left, int right, int node) {
if (left == right) return rangeMin[node] = array[left];
int mid = (left + right) / 2;
int leftMin = init(array, left, mid, node * 2);
int rightMin = init(array, mid + 1, right, node * 2 + 1);
return rangeMin[node] = Math.min(leftMin, rightMin);
}
}
이러한 초기화 과정은 각 노드마다 O(1) 시간이 걸리고, 노드의 수는 n개 이므로 시간복잡도는 O(n)이다.
세그먼트 트리의 질의 처리
이번엔 임의의 구간의 최소치를 구해보자. 이를 세그먼트 트리에서의 질의(query) 연산이라고 부른다. 질의 연산을 수행하는 query()
함수는 순회를 응용하여 간단하게 구현할 수 있다. query()는 우선 node가 표현하는 구간 [nodeLeft, nodeRight]와 최소치를 찾을 구간 [left, right]의 교집합을 구한 뒤, 그에 따라 서로 다른 값을 반환한다.
- 교집합이 공집합인 경우: 두 구간이 서로 겹치지 않으므로 반환값은 존재하지 않는다. 반환값을 무시하기 위해 아주 큰 값을 반환하도록 하자.
- 교칩합이 [nodeLeft, nodeRight]인 경우: [left, right]가 node가 표현하는 집합을 완전히 포함하는 경우이다. 이 경우엔 단순히 node에 계산해 둔 최소치를 반환하면 된다.
- 이 외의 모든 경우: 두 개의 자손 노드에 대해 query()를 재귀 호출한 뒤, 이 두 값 중 더 작은 값을 택해 반환한다.
public class RMQ {
private static int INT_MAX = 987654321;
... 생략 ...
// 외부 인터페이스
public int query(int left, int right) {
return query(left, right, 1, 0, n - 1);
}
/**
* 구간 트리의 질의 처리
* node가 표현하는 범위 array[nodeLeft..nodeRight]가 주어질 때, 이 범위와 array[left..right]의 교집합의 최소치를 구한다
* O(lgN)
*
* @param left
* @param right
* @param node
* @param nodeLeft
* @param nodeRight
* @return
*/
private int query(int left, int right, int node, int nodeLeft, int nodeRight) {
// 두 구간이 겹치지 않으면 아주 큰 값을 반환하여 무시함
if (right < nodeLeft || left > nodeRight) return INT_MAX;
// node가 표현하는 범위가 array[left..right]에 완전히 포함되는 경우
if (left <= nodeLeft && right >= nodeRight) return rangeMin[node];
// 양쪽 구간을 나눠서 푼 뒤 결과를 합친다.
int mid = (nodeLeft + nodeRight) / 2;
return Math.min(query(left, right, node * 2, nodeLeft, mid),
query(left, right, node * 2 + 1, mid + 1, nodeRight));
}
}
- 어떤 구간을 반으로 쪼갠 후에는 반드시 한 쪽은 '교집합이 [nodeLeft, nodeRight]인 경우'가 발생하므로, 전체 시간 복잡도는 O(lgN)이다.
만약 구간에서의 최소치를 구하는 것이 아니라, 최대치를 구하는 문제였다면 Math.min이 아니라 Math.max를 사용하면 될 것이고, 구간 합을 구하는 문제였다면 단순히 왼쪽 구간의 재귀 호출 결과와 오른쪽 구간의 재귀 호출 결과를 더해주면 됐을 것이다.
세그먼트 트리의 갱신
세그먼트 트리는 O(lgN) 시간에 요소 1개를 갱신할 수 있다. 갱신 과정은 query()와 init()을 합친 것처럼 구현된다. 동작은 다음과 같다.
- 해당 노드가 표현하는 구간 [nodeLeft, nodeRight]에 index가 포함되지 않는다면 무시한다.
- 포함된다면 재귀 호출을 통해 두 자손 구간의 최소치를 계산한 뒤, 다시 최소치를 구해준다.
public class RMQ {
// 외부 인터페이스
public int update(int index, int newValue) {
return update(index, newValue, 1, 0, n - 1);
}
/**
* 구간 트리의 갱신
* array[index] = newValue로 바뀌었을 때 node를 루트로 하는 구간 트리를 갱신하고 노드가 표현하는 구간의 최소치를 반환한다.
* O(lgN)
*
* @param index
* @param newValue
* @param node
* @param nodeLeft
* @param nodeRight
* @return
*/
private int update(int index, int newValue, int node, int nodeLeft, int nodeRight) {
// index가 노드가 표현하는 구간과 상관없는 경우 무시한다
if (index < nodeLeft || index > nodeRight) return rangeMin[node];
// 트리의 리프까지 내려온 경우, 갱신 후 반환한다
if (nodeLeft == nodeRight) return rangeMin[node] = newValue;
// 재귀 호출을 통해 최소치를 구한 후 수정
int mid = (nodeLeft + nodeRight) / 2;
return rangeMin[node] = Math.min(update(index, newValue, node * 2, nodeLeft, mid),
update(index, newValue, node * 2 + 1, mid + 1, nodeRight));
}
}
완성 코드
public class RMQ {
private static int INT_MAX = 987654321;
private int n; // 배열의 길이
private int[] rangeMin;
public RMQ(int[] array) {
n = array.length;
rangeMin = new int[n * 4];
init(array, 0, n - 1, 1);
}
// node 노드가 array[left..right] 배열을 표현할 때
// node를 루트로 하는 서브트리를 초기화하고, 이 구간의 최소치를 반환한다.
private int init(int[] array, int left, int right, int node) {
if (left == right) return rangeMin[node] = array[left];
int mid = (left + right) / 2;
int leftMin = init(array, left, mid, node * 2);
int rightMin = init(array, mid + 1, right, node * 2 + 1);
return rangeMin[node] = Math.min(leftMin, rightMin);
}
public int query(int left, int right) {
return query(left, right, 1, 0, n - 1);
}
/**
* 구간 트리의 질의 처리
* node가 표현하는 범위 array[nodeLeft..nodeRight]가 주어질 때, 이 범위와 array[left..right]의 교집합의 최소치를 구한다
* O(lgN)
*
* @param left
* @param right
* @param node
* @param nodeLeft
* @param nodeRight
* @return
*/
private int query(int left, int right, int node, int nodeLeft, int nodeRight) {
// 두 구간이 겹치지 않으면 아주 큰 값을 반환하여 무시함
if (right < nodeLeft || left > nodeRight) return INT_MAX;
// node가 표현하는 범위가 array[left..right]에 완전히 포함되는 경우
if (left <= nodeLeft && right >= nodeRight) return rangeMin[node];
// 양쪽 구간을 나눠서 푼 뒤 결과를 합친다.
int mid = (nodeLeft + nodeRight) / 2;
return Math.min(query(left, right, node * 2, nodeLeft, mid),
query(left, right, node * 2 + 1, mid + 1, nodeRight));
}
public int update(int index, int newValue) {
return update(index, newValue, 1, 0, n - 1);
}
/**
* 구간 트리의 갱신
* array[index] = newValue로 바뀌었을 때 node를 루트로 하는 구간 트리를 갱신하고 노드가 표현하는 구간의 최소치를 반환한다.
* O(lgN)
*
* @param index
* @param newValue
* @param node
* @param nodeLeft
* @param nodeRight
* @return
*/
private int update(int index, int newValue, int node, int nodeLeft, int nodeRight) {
// index가 노드가 표현하는 구간과 상관없는 경우 무시한다
if (index < nodeLeft || index > nodeRight) return rangeMin[node];
// 트리의 리프까지 내려온 경우, 갱신 후 반환한다
if (nodeLeft == nodeRight) return rangeMin[node] = newValue;
int mid = (nodeLeft + nodeRight) / 2;
return rangeMin[node] = Math.min(update(index, newValue, node * 2, nodeLeft, mid),
update(index, newValue, node * 2 + 1, mid + 1, nodeRight));
}
public int[] queryTwoElements(int left, int right) {
return queryTwoElements(left, right, 1, 0, n - 1);
}
public static void main(String[] args) {
int[] array = {0, 3, 2, 7, 9, 11, 4, 6};
RMQ rmq = new RMQ(array);
// 테스트 1: 구간 최소 쿼리
System.out.println("Minimum in range [1, 4]: " + rmq.query(1, 4)); // Expected: 2
System.out.println("Minimum in range [0, 7]: " + rmq.query(0, 7)); // Expected: 0
// 테스트 2: 업데이트 후 구간 최소 쿼리
rmq.update(2, 0);
System.out.println("After update, minimum in range [1, 4]: " + rmq.query(1, 4)); // Expected: 0
// 테스트 3: 두 최소값 구하기
System.out.println("Two minimums in range [0, 7]: " + Arrays.toString(rmq.queryTwoElements(0, 7))); // Expected: [0, 1]
}
}
예제: 정렬된 수열의 특정 구간에서 최대 출현 빈도 계산 (RangeResult)
이번에는 세그먼트 트리를 사용하여 정렬된 정수 수열 A[]가 주어질 때, 주어진 구간의 '최대 출현 빈도'를 계산하는 문제를 풀어보자. 어떤 수열의 최대 출현 빈도란 이 수열에서 가장 자주 등장하는 수의 출현 횟수이다. 예를 들어 A={0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4}라면 A[1] ~
A[5]까지 최대 출현 빈도는 1이 4번 등장하므로 4이다. A[5]~
A[7]에서는 1,2,3이 모두 한 번씩 출현하므로 1이다.
이 문제에서는 지금까지 처럼 두 개의 부분 구간에 대해 푼 결과를 합치는 것이 쉽지않다.. 두 구간 중 최댓값만을 취하고 싶지만 다음과 같은 예외 상황이 존재한다.
- 두 구간에서 가장 많이 출현하는 수가 같고, 두 구간을 이었을 때 수열이 이어지는 경우: 두 답 중 더 큰 쪽을 택하는 것이 아니라, 두 답의 합을 반환한다.
- 두 구간을 이어 보면 가장 많이 출현하는 수가 바뀌는 경우: 양쪽 부분에서 가장 많이 출현하는 수가 각각 1과 3이었지만, 두 국나을 합쳐보면 2가 가장많이 출현하게 될 수도 있다.
위와 같은 예외 상황으로 인해 단순히 Math.max(왼쪽, 오른쪽)
과 같은 방식은 통하지 않는다.. 이런 문제를 풀 때는 문제의 답만이 아니라 두 개의 답을 합치는 데 필요한 추가 정보도 계산해서 반환할 필요가 있다. 왼쪽 구간의 오른쪽 끝 수가 오른쪽 구간의 왼쪽 끝 수와 같다는 데에 착안하여, 모든 구간에 대해 답을 계산할 때 구간의 왼쪽 끝 수(leftNumber
)와 오른쪽 끝 수(rightNumber
), 그리고 그들의 갯수(leftFreq
, rightFreq
)를 함께 계산하자. 그리고 이 두 부분 구간을 합치려면, 왼쪽 구간의 오른쪽 끝 숫자와 오른쪽 구간의 왼쪽 끝 숫자가 같은 경우 이들을 합쳐 보고, 최대 출현 빈도가 바뀌는지 확인하면 된다.
이를 구현한 코드는 다음과 같다.
public class RangeResult {
private int size;
private int mostFrequent; // 가장 자주 등장하는 숫자의 출현 횟수
private int leftNumber, leftFreq; // 왼쪽 끝 숫자와 왼쪽 끝 숫자의 출현 횟수
private int rightNumber, rightFreq; // 오른쪽 끝 숫자와 오른쪽 끝 숫자의 출현 횟수
// 왼쪽 부분 구간의 계산 결과 a, 오른쪽 부분 구간의 계산 결과 b를 합친다.
public RangeResult merge(RangeResult a, RangeResult b) {
RangeResult ret = new RangeResult();
ret.size = a.size + b.size;
// 왼쪽 부분 구간이 전부 a.leftNumber인 경우
// ex) [1, 1, 1, 1]과 [1, 2, 2, 2]를 합칠 때
ret.leftNumber = a.leftNumber;
ret.leftFreq = leftFreq;
if (a.size == a.leftFreq && a.leftNumber == b.leftNumber) {
ret.leftFreq += b.leftFreq;
}
// 오른쪽 끝 숫자도 비슷하게 계산
ret.rightNumber = b.rightNumber;
ret.rightFreq = b.rightFreq;
if (b.size == b.rightFreq && a.rightNumber == b.rightNumber) {
ret.rightFreq += a.rightFreq;
}
// 기본적으로 가장 많이 출현하는 수의 빈도수는 둘 중 큰 쪽으로
ret.mostFrequent = Math.max(a.mostFrequent, b.mostFrequent);
// 왼쪽 구간의 오른쪽 끝 숫자와 오른쪽 구간의 왼쪽 끝 숫자가 합쳐지는 경우
// 이 두수를 합쳤을 때 mostFrequent보다 커지는지 확인한다
if (a.rightNumber == b.leftNumber) {
ret.mostFrequent = Math.max(ret.mostFrequent, a.rightFreq + b.leftFreq);
}
return ret;
}
}
백준 문제 풀어보기: 구간 합 구하기(G1)
import java.io.*;
import java.util.StringTokenizer;
public class BOJ2042 {
private static long[] rangeSum;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
StringTokenizer st = new StringTokenizer(br.readLine());
int N = Integer.parseInt(st.nextToken());
int M = Integer.parseInt(st.nextToken());
int K = Integer.parseInt(st.nextToken());
long[] tempArr = new long[N];
rangeSum = new long[N * 4];
for (int i = 0; i < N; i++) {
tempArr[i] = Long.parseLong(br.readLine());
}
init(tempArr, 0, N - 1, 1);
for (int i = 0; i < M + K; i++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
long c = Long.parseLong(st.nextToken());
if (a == 1) update(b - 1, c, 1, 0, N - 1);
else if (a == 2) bw.append(query(b - 1, (int) c - 1, 1, 0, N - 1) + "\n");
}
bw.flush();
bw.close();
br.close();
}
public static long init(long[] array, int left, int right, int node) {
if (left == right) return rangeSum[node] = array[left];
int mid = (left + right) / 2;
long leftSum = init(array, left, mid, node * 2);
long rightSum = init(array, mid + 1, right, node * 2 + 1);
return rangeSum[node] = leftSum + rightSum;
}
public static long query(int left, int right, int node, int nodeLeft, int nodeRight) {
if (right < nodeLeft || left > nodeRight) return 0;
if (left <= nodeLeft && right >= nodeRight) return rangeSum[node];
int mid = (nodeLeft + nodeRight) / 2;
return query(left, right, node * 2, nodeLeft, mid) +
query(left, right, node * 2 + 1, mid + 1, nodeRight);
}
public static long update(int index, long newValue, int node, int nodeLeft, int nodeRight) {
if (index < nodeLeft || index > nodeRight) return rangeSum[node];
if (nodeLeft == nodeRight) return rangeSum[node] = newValue;
int mid = (nodeLeft + nodeRight) / 2;
return rangeSum[node] = update(index, newValue, node * 2, nodeLeft, mid) +
update(index, newValue, node * 2 + 1, mid + 1, nodeRight);
}
}
이진 인덱스 트리(BIT, Binary Inedx Tree)
위 문제에서 처럼 세그먼트 트리의 가장 흔한 사용 예는 구간 합을 빠르게 구하는 것이다. 하지만, 이 경우 세그먼트 트리 대신 쓸 수 있는 세그먼트 트리의 궁극적인 진화 형태로 '펜윅 트리(Fenwick Tree)' 혹은 '이진 인덱스 트리(Binary Index Tree, BIT)'라고 불리는 것이 있다!
부분 합(partial sum): 배열 A의 첫 몇 i개의 원소의 합
구간 합(range sum): A의 연속된 부분 배열의 합
즉, 부분 합은 첫 위치가 A[0]으로 고정된 구간 합이다!
이진 인덱스 트리의 원리
이진 인덱스 트리가 사용하는 중요한 아이디어는 구간 합(range sum) 대신 부분 합(partial sum)만을 빠르게 계산할 수 있다면, 구간 합을 계산할 수 있다는 것이다. 배열 A의 위치 pos에 대해 부분 합 psum[pos] = A[0] + A[1] + ... + A[pos]
를 빠르게 계산할 수 있다면, [i, j] 구간 합은 psum[j] - psum[i-1]
로 계산할 수 있다. 즉, 부분합을 통해 구간합을 빠르게 게산할 수 있다!
그렇다면 우리는 세그먼트 트리의 각 원소에 '구간 합'을 저장한다고 생각해 보면, 세그먼트 트리가 미리 계산해 저장하는 정보의 상당수는 필요가 없다.. 위 그림은 길이 16인 배열의 구간 합을 구하기 위해 세그먼트 트리가 계산해 저장하는 각 구간들을 보여준다. 이는 위 백준 예제를 풀 때의 형태와 동일한 세그먼트 트리이다. 하지만 잘 생각해보면, [8, 15] 구간의 구간 합은 사실 부분 합만을 구한다면 필요가 없다. 이는 단순히 psum[15] - psum[7]을 계산함으로써 구할 수 있기 때문이다. 같은 원리로, 하나의 긴 구간 밑에 두 개의 작은 구간이 있을 때 이 두 구간 중 오른쪽 구간은 항상 저장할 필요가 없다.
위 그림에서 아래쪽 그림은 이 구간 중 필요한 부분만 남긴 결과를 보여준다. 남은 구간의 수는 정확하게 n개가 된다. 또한, 각 구간이 포함하는 오른쪽 끝 원소들을 보면 이들이 서로 모두 다르다는 것을 알 수 있다. 이진 인덱스 트리는 이 대응을 이용해 1차원 배열 하나에 각 구간 합을 저장한다.
tree[i] = 위 그림의 아래쪽 그림에서 오른쪽 끝 위치가 A[i]인 구간의 합
이렇게 각 구간 합을 저장해두면 부분 합은 어떻게 계산할 수 있을까?
A[0]~A[pos]
까지의 부분 합 psum[pos]
를 게산하고 싶으면 tree[pos]를 답에 더한 후, 남은 부분들을 왼쪽에서 찾아 더하면 된다. 예를 들면, psum[12]
는 tree[12] + tree[11] + tree[7]
을 통해, 즉 [12, 12]
, [8, 11]
, [0, 7]
의 구간 합들을 더함으로써 구할 수 있다. 이를 통해 어떤 부분 합을 구하든 O(lgN) 개의 구간 합만 있으면 된다는 것을 알 수 있다.
그렇다면 이제 문제는 pos에서 끝나는 구간 다음으로 더해야 할 구간을 어떻게 찾을까 하는 것인데, 이진 인덱스 트리는 각 숫자의 이진수 표현을 이용해 이 문제를 해결한다.
정수에 따른 2진수 표기는 본 포스팅에서 설명하지 않겠다. 음수를 표현할 때 2의 보수법을 사용한다는 것을 기억하길 바란다. 이를 이해한다면 우리는 쉽게 '0이 아닌 마지막 비트'를 찾을 수 있다. 특정한 숫자 K의 0이 아닌 마지막 비트를 찾기 위해서는 K & -K
연산을 수행하면 된다. 다음은 K & -K 연산 결과 예시이다.
이를 활용하기 위해 우선, 배열 A[]와 tree[]의 첫 원소의 인덱스를 1로 바꾸자! 즉, 모든 원소의 인덱스에 1을 더해주자. 그러고 난 후, 0이 아닌 마지막 비트를 이용하면 특정 부분 합을 구하기 위해 더해야 할 구간 합들을 쉽게 찾을 수 있다.
위 그림처럼 각 구간의 길이, 즉 각 구간이 저장하고 있는 값들의 개수는 2^(0이 아닌 마지막 비트)
가 된다. 리프에서 시작해서 한 층씩 위로 올라갈 때마다 구간의 길이는 2배가 된다. (그림에서 숫자 밑의 이진수는 0이 아닌 마지막 비트를 나타낸 것이다)
부분 합을 구하기 위해 더해야 하는 구간들의 번호도 이들의 이진수 표현과 관계가 있다. psum[7] = tree[7] + tree[6] + tree[4]
임을 통해서, 현재 pos의 이진수 표현에서 0이 아닌 마지막 비트만큼 뺀 위치로 이동해가며 구간 합을 더한다면 부분합을 계산할 수 있음을 알 수 있다.
7, 6, 4를 각각 2진수로 표현하면, 111, 110, 100이다. 111에서 0이 아닌 마지막 비트인 1을 빼면 110이 되고, 110에서 0이 아닌 마지막 비트인 10을 빼면 100이 된다.
이번엔 이진 인덱스 트리에서 배열의 값을 갱신하는 연산도 확인해보자. 이는 해당 위치의 값에 숫자를 더하고 빼는 방식으로 구현한다. 예를 들어, A[5]를 3 늘리고 싶다면, A[5]를 포함하는 모든 구간의 구간 합들을 3씩 늘려주면 된다. 이때 늘려줘야 할 값들은 tree[5], tree[6], tree[8], tree[16]으로, 각 인덱스의 이진수 표현은 101
, 110
, 1000
, 10000
이다. 이 역시 해당 위치의 이진수 표현에서 0이 아닌 마지막 비트를 더해주는 것을 반복하여 다음 위치를 구할 수 있다.
이를 취합하여 이진 인덱스 트리를 구현한 코드는 다음과 같다.
이진 인덱스 트리의 구현
// 펜윅 트리의 구현. 가상의 배열 A[]의 부분합을 빠르게 구할 수 있도록 한다.
// 초기화시에는 A[]의 원소가 전부 0이라고 생각한다.
public class FenwickTree {
int[] tree;
FenwickTree(int N) {
tree = new int[N + 1];
}
/**
* A[0..pos]의 부분 합을 구한다
* O(lgN)
*
* @param pos
* @return
*/
public int prefixSum(int pos) {
int ret = 0;
while (pos > 0) {
ret += tree[pos];
pos -= (pos & -pos); // 다음 구간을 찾기 위해 최종 비트를 지운다
}
return ret;
}
/**
* A[pos]에 val을 더한다.
* O(lgN)
*
* @param pos
* @param val
*/
public void update(int pos, int val) {
while (pos < tree.length) {
tree[pos] += val;
pos += (pos & -pos);
}
}
/**
* A[start..end]까지의 구간 합을 구한다. psum[end] - psum[start - 1]
*
* @param start
* @param end
* @return
*/
public int intervalSum(int start, int end) {
return prefixSum(end) - prefixSum(start - 1);
}
}
부분 합을 구하는 연산과 배열의 값을 갱신하는 연산의 시간 복잡도는 모두 O(lgN)이다.
이렇게 이진 인덱스 트리의 구현은 매우 간결하다! 때문에 "게속 변하는 배열의 구간 합을 구할 때"는 구간 트리보다 이진 인덱스 트리를 훨씬 자주 쓰게 된다.
백준 문제 풀어보기: 구간 합 구하기(G1)
위에서 세그먼트 트리로 풀었던 예제를 이진 인덱스 트리를 통해 다시 풀어보자!
import java.io.*;
import java.util.StringTokenizer;
public class BOJ2042_2 {
private static int N, M, K;
private static long[] arr, tree;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
StringTokenizer st = new StringTokenizer(br.readLine());
N = Integer.parseInt(st.nextToken());
M = Integer.parseInt(st.nextToken());
K = Integer.parseInt(st.nextToken());
arr = new long[N + 1];
tree = new long[N + 1];
for (int i = 1; i < N + 1; i++) {
long x = Long.parseLong(br.readLine());
arr[i] = x;
update(i, x);
}
for (int i = 0; i < M + K; i++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
long c = Long.parseLong(st.nextToken());
if (a == 1) {
update(b, c - arr[b]);
arr[b] = c;
} else if (a == 2) bw.append(intervalSum(b, (int) c) + "\n");
}
bw.flush();
bw.close();
br.close();
}
public static long prefixSum(int pos) {
long ret = 0;
while (pos > 0) {
ret += tree[pos];
pos -= (pos & -pos);
}
return ret;
}
public static void update(int pos, long val) {
while (pos <= N) {
tree[pos] += val;
pos += (pos & -pos);
}
}
public static long intervalSum(int start, int end) {
return prefixSum(end) - prefixSum(start - 1);
}
}
'Data Structure' 카테고리의 다른 글
[Data Structure] 그래프의 정의와 표현 (2) | 2024.12.08 |
---|---|
[Data Structure] 유니온 파인드 (3) | 2024.12.05 |
[Data Structure] 트리, 이진 검색 트리, 우선 순위 큐와 힙 (1) | 2024.12.02 |
[Data Structure] 큐와 스택, 덱 (1) | 2024.11.27 |
[Data Structure] 선형 자료 구조(동적 배열 & 연결 리스트) (1) | 2024.11.27 |