0. 세그먼트 트리(Segment Tree)란?
세그먼트 트리(Segment Tree)
어떤 데이터가 주어질 때, 특정 구간의 결과값(구간합, 최댓값 등)을 구하는데 사용하는 자료구조
세그먼트 트리는 이진 트리(Binary Tree) 구조를 가지고 있으며, 특정 구간의 결과값을 시간복잡도 O(logN)으로 빠르게 구할 수 있다는 장점이 있다.
또한, 자료의 수정이 빈번히 일어날 때, 사용한다.
이 게시글에서는 누적합을 구하는 세그먼트 트리 기준으로 설명을 한다.
만약 구간 최대나 구간 최소를 알고 싶다면, 반환하는 부분만 수정하면 된다.
1. 세그먼트 트리 알아보기
1-1. 세그먼트 트리 생성(초기화)
아래는 [5, 7, 6, 3, 1, 9]인 원본 배열로부터 만든 세그먼트 트리이다.
파란 글씨는 해당 노드가 가리키는 구간의 범위를 나타낸다.
또한, 세그먼트 트리의 값은 배열을 이용하여 저장하며, 보통 크기는 4N개로 선언한다.
해당 그림의 결과를 배열로 표현하면 다음과 같다.
sum = [_, 31, 18, 13, 12, 6, 4, 9, 5, 7, _, _, 3, 1, _, _]
// _는 사용하지 않는 칸
1-2. 배열의 값 변경(업데이트)
3번 위치의 값을 100으로 바꾸고 싶은 경우, 아래처럼 수정하면 된다.
1-3. 세그먼트 트리 구간합 계산(쿼리)
2부터 5까지 범위의 구간합을 구하고 싶다면, 색칠된 값만 알면 된다.
1-4. 세그먼트 트리의 크기가 4N인 이유
N개의 배열에서 최대 깊이는 logN이다.
logN 깊이에서의 노드의 개수는 최대 2N개이다.
그럼 총 노드의 개수는 1 + 2 + 4 + ... + N + 2N이 되며,
(1 + 2 + 4 + ... + N) + 2N = (2N - 1) + 2N = 4N - 1
이 과정에 따라 4N개로 선언한다.
1-4-1. 1 + 2 + 4 + ... + N = 2N - 1인 이유
하나씩 해봐도 쉽게 알 수 있지만, 원리를 이해하기 위해서는 비트로 표현하는 것이 제일이다.
1 + 2 + 4 = 7 = 8 - 1, 이 식을 2진수로 표현하면 001 + 010 + 100 = 111 = 1000 - 1이다.
즉, 1부터 2의 배수로 커지는 값의 합은, 마지막 항을 x라고 했을 때 2x - 1이 된다.
따라서, 1부터 N까지 2배씩 커지는 등비수열의 합은 2N - 1이 되는 것이다.
2. 세그먼트 트리 연산
세그먼트 트리의 연산은 앞서 봤듯 초기화, 업데이트, 쿼리로 총 3가지이다.
모두 bottom-up 방식으로 전체 범위에서 시작하여 절반의 범위씩 줄여가는 재귀 방식으로 구현한다.
2-1. 초기화
초기에 주어진 데이터를 활용해 세그먼트 트리를 생성하기 위한 연산으로 시간복잡도는 O(NlogN)이다.
- 현재 범위의 크기가 1이라면,
- 현재 노드 위치에 현재 범위에 해당하는 값 저장
- 해당 값 반환
- 현재 범위의 크기가 1이 아니라면,
- 왼쪽 절반 범위의 구간합을 구하는 초기화 연산 호출
- 오른쪽 절반 범위의 구간합을 구하는 초기화 연산 호출
- 두 호출 결과를 더한 값 저장 후 반환
2-2. 업데이트
특정 위치의 값이 바뀔 경우, 해당 위치를 포함하는 범위의 값을 수정하기 위한 연산으로 시간복잡도는 O(logN)이다.
- 현재 범위가 수정할 위치를 포함하지 않는다면,
- 현재 노드 값 반환
- 현재 범위의 크기가 1이라면,
- 현재 노드 위치에 수정하려는 값 저장
- 해당 값 반환
- 현재 범위의 크기가 1이 아니라면,
- 왼쪽 절반 범위의 구간합을 수정하는 업데이트 연산 호출
- 오른쪽 절반 범위의 구간합을 수정하는 업데이트 연산 호출
- 두 호출 결과를 더한 값 저장 후 반환
2-3. 쿼리
특정 구간의 구간합을 구하기 위한 연산으로 시간복잡도는 O(logN)이다.
- 현재 범위가 쿼리 범위를 아예 포함하지 않는다면
- 0을 리턴
- 현재 범위가 쿼리 범위의 안에 완벽하게 포함된 경우,
- 현재 노드 값 반환
- 현재 범위가 쿼리 범위의 일부에 포함된 경우,
- 왼쪽 절반 범위의 구간합을 구하는 쿼리 연산 호출
- 오른쪽 절반 범위의 구간합을 구하는 쿼리 연산 호출
- 두 호출 결과를 더한 값 반환
3. 코드
public class Main {
static final int ARRAY_SIZE = 6;
static final int NONE = -1;
static long[] array = {NONE, 5, 7, 6, 3, 1, 9};
static long[] sum = new long[ARRAY_SIZE * 4 + 1];
public static long init(int left, int right, int node) {
if (left == right)
return sum[node] = array[left];
int mid = (left + right) / 2;
return sum[node] = init(left, mid, node * 2) + init(mid + 1, right, node * 2 + 1);
}
public static long update(int idx, long value, int left, int right, int node) {
if (idx < left || idx > right)
return sum[node];
if (left == right)
return sum[node] = array[idx] = value;
int mid = (left + right) / 2;
return sum[node] = update(idx, value, left, mid, node * 2) + update(idx, value, mid + 1, right, node * 2 + 1);
}
public static long query(int queryLeft, int queryRight, int left, int right, int node) {
if (right < queryLeft || queryRight < left)
return 0;
if (queryLeft <= left && right <= queryRight)
return sum[node];
int mid = (left + right) / 2;
return query(queryLeft, queryRight, left, mid, node * 2) + query(queryLeft, queryRight, mid + 1, right, node * 2 + 1);
}
public static void main(String[] args) throws Exception {
init(1, ARRAY_SIZE, 1);
System.out.println("= 구간합 초기화");
for (int idx = 1; idx <= 16; idx++)
System.out.print(sum[idx] + " ");
System.out.println();
System.out.println();
System.out.println("= 1~3 사이의 구간합: " + query(1, 3, 1, ARRAY_SIZE, 1));
System.out.println();
update(3, 100, 1, ARRAY_SIZE, 1);
System.out.println("= 3번 위치 값 100로 수정");
for (int idx = 1; idx <= 16; idx++)
System.out.print(sum[idx] + " ");
System.out.println();
System.out.println();
System.out.println("= 1~3 사이의 구간합: " + query(1, 3, 1, ARRAY_SIZE, 1));
}
}