본문 바로가기
알고리즘

펜윅 트리 쉽게 이해하기.

by 강성주의 알고리즘 2023. 7. 21.

알고리즘에 입문을 해서 문제를 풀다 보면 구간 합과 관련된 문제를 접하게 된다. 우리는 구간 합 문제를 O(N) 복잡도의 누적 합을 구하고 특정 구간의 원소들의 합을 O(1) 만에 구할 수 있는 방법을 습득하게 된다.
누적 합이 몇인지 묻는 질의(또는 쿼리)가 M개가 주어진다면, 우리는 O(M)의 복잡도로 문제를 빠르게 해결할 수 있다. 만약, 위와 같은 질의 중간중간에 원소의 값이 바뀌는 새로운 질의가 주어진다면 어떻게 해결해야 할까? 매번 O(N)을 수행해서 누적 합을 저장하는 배열의 값을 업데이트해주어야 할까? 그럼 O(NM)의 복잡도로 시간 초과를 받게 될 것이다.
구간 트리 또는 세그먼트 트리(segment tree)라는 것이 있지만 누적 합과 관련된 문제에서 좀 더 간단하게 구현할 수 있는 펜윅 트리(fenwick tree)라는 것을 소개하려 한다. 나는 이미 세그먼트 트리를 완벽하게 구현할 수 있다 하는 사람들도 뒤로 가기를 누르면 안 된다 ㅠㅠ.
펜윅 트리는 N개의 수를 길이 N의 배열에 구간 합을 저장할 수 있는 특수한 트리이다. Binary Indexed Tree, BIT라고도 한다는데 용어는 중요한 게 아니니.. 사실 펜윅 발음도 어렵고, 필자는 펜웍이라고 알고 있었던 때도 있었다. 워윅이나 워웍이나..
펜윅 트리를 나타내는 길이 N = 8의 예시가 아래 그림이다. 각 상자마다 시작 구간 ~ 끝 구간이 표시되어 있는데 해당 상자에는 구간의 합이 저장된다. 또한, 끝 구간이 해당 박스의 배열 인덱스이다. (1번 인덱스부터 사용한다고 하자. 구현할 때 1부터 써야 하는 이유가 나온다.)

N = 8 인 상황에서의 펜윅 트리 예시

5~6 이면 배열의 6번 인덱스의 값이라는 것이고 1차원 배열로 쭉 이어서 표현하면 아래 그림과 같다.

실제로는 이렇게 저장됩니다

다시 돌아와서 펜윅 트리에서 만약 3 ~ 6 누적합을 구하고 싶다 한다면 어떻게 구할 수 있을까? 이전에 우리가 알고 있던 누적 합에서의 구간 합을 구할 땐 (1~6 구간의 합) - (1~2 구간의 합)으로 구했을 것이다. 펜윅도 똑같다. 대충 아래 그림을 봐보자. 배열의 이름이 fw이라면 fw[4] + fw[6] - fw[2] 를 하면 구간 합을 구할 수 있다. 

펜윅 트리에서의 3~6 구간합을 보여주는 예시

여기서 펜윅 트리가 BIT로 불리는 이유가 등장한다. binary의 이름에 걸맞게 우리가 구하고자 하는 구간의 끝 부분을 이진수로 나타내보자

6 = 110
2 = 010

솔직히 규칙이 잘 안 보인다. 그래서 좀 더 큰 걸 가져왔다. 8~14 구간의 합을 구하는 예시이다. (1~14 구간의 합) - (1~7 구간의 합) 

1~14 구간의 합은 fw[14] + fw[12] + fw[8] 이며, 1~7 구간의 합은 fw[7] + fw[6] + fw[4] 이다.
규칙이 안 보이면 어쩔 수 없다.
14를 이진수로 나타냈을 때, 1110이며 구간의 합을 의미했던 14, 12, 8은 14의 이진수에서 가장 오른쪽 1이 사라지는 과정이 된다. 이게 언제까지? 해당 숫자가 0이 될 때까지. 그래서 배열의 인덱스가 1부터 시작했던 것이다. 
다시 표현해 보면 14 = 1110 , 12 = 1100, 8 = 1000이며, 8에서 가장 오른쪽 1을 지우면 0000 이 되므로 더 이상 더해줄 구간 합이 없다는 것을 의미한다.
어쩌다 보니 구간 합을 구하는 코드를 먼저 설명하게 되었다. 암튼 우리가 구하고자 하는 1~x 구간의 x를 위와 같이 가장 마지막 비트 1을 지워주면서 0이 될 때까지 반복하면 구할 수 있다. 가장 마지막 1 비트는 어떻게 구할까?
비트 마스크때 했던 것처럼 아래와 같이 시프트 연산자로도 구할 수 있다. 근데 더 쉬운 방법이 있다.

int idx = -1;
for (int i = 0; i < 31; i++) {
	if ((1 << i) & X) {
		idx = i;
		break;
	}
}
X = X - (1 << idx);

우리가 처음 프로그래밍을 공부하면 음수를 표현하는 방법을 배웠었다. 2의 보수법 (링크 two's complement)라고 했던 것 같다.

+14 = 00001110
-14 = 11110001 + 00000001 = 11110010

14와 -14를 자세히 보니까 모든 비트가 다른데 가장 마지막 1의 비트만 같은 것을 볼 수 있다. 이걸 이용하면 X와 -X를 AND 연산자 (&)를 취해주면, 시프트 하면서 비교를 안 해도 바로 알 수 있다. 홀리 쒯..

X -= (X & -X);

구간 합은 이렇게 구할 수 있겠다..

int get_sum(int X) {
	int sum = 0;
	while ( X > 0 ) {
		sum += fw[X];
		X -= (X & -X);
	}
	return sum;
}

그럼 이제 펜윅 트리에 구간 합을 구할 수 있도록 누적 합을 저장하는 과정만 알면 된다. 다시 리마인드 하는 것이지만, 펜윅 트리는 중간중간에 원소의 값이 바뀔 때도 구간 합을 빠르게 구하는 목적으로 사용하는 것이다.
우리가 1번 인덱스의 값을 수정할 때, 영향을 받는 인덱스 번호는 몇 번이 있을까? 아래 그림을 보자!

1번 위치의 값이 변경되었을 때, 영향을 받는 위치.

1번을 수정했더니, 2, 4, 8 이 바뀌었다. 이것도 한번 이진수로 나타내보자.

1 = 0001
2 = 0010
4 = 0100
8 = 1000

어? 뭔가 2배씩 늘어나는 것 같다. 이번엔 5번도 수정해 보자.

5번 위치의 값이 변경되었을 때, 영향을 받는 위치.

5번을 수정했더니, 5, 6, 8 이 바뀌었다.

5 = 0101
6 = 0110 = 0101 + 0001
8 = 1000 = 0110 + 0010

이번엔 2배는 아니지만 뭔가 더해지는 것 같다. 구간 합을 구할 땐, 가장 마지막 비트를 지워가면서 X를 0으로 만들었다. 누적 합을 수정할 땐 가장 마지막 비트를 더해가면서 X가 N 보다 커지면 업데이트를 중단하면 된다. v는 현재 X 번째 값과 새로 넣고자 하는 값의 차이 (변화량)를 넘겨주면 된다.

void update(int X, int v) {
	while (X <= N) {
		fw[X] += v;
		X += (X & -X);
	}
}

정리해 보면, 값을 수정할 때는 수정한 인덱스 X로부터 X가 증가하는 방향으로 인덱스를 이동해 주면 되며 구간의 합을 구할 땐 구하고자하는 구간의 끝 인덱스 X로 부터 X가 감소하는 방향으로 인덱스를 이동해주면 된다. 그림으로 표현하면 아래와 같다.

구간 합을 구하는 이동방향 (왼쪽그림), 값을 수정하는 이동방향 (오른쪽 그림)

내가 만약 길이 4 짜리 배열 [1, 5, 3, 9]를 펜윅 트리로 나타내고 싶다면 두 개의 배열만 잡아주면 된다. 하나는 원래 값을 저장하고 있는 배열 A, 다른 하나는 펜윅 트리를 표현하는 배열 fw.
중간중간 값이 변할 때, 배열에서 바로 현재 값과 바꾸고자 하는 값의 차이를 O(1) 복잡도로 구할 수 있기 때문에. (아래처럼)

int new_value;
int cur_value = A[X];
int dif = A[X] - new_value;
update(X, dif);
A[X] = new_value;

https://www.acmicpc.net/problem/1280

 

1280번: 나무 심기

첫째 줄에 나무의 개수 N (2 ≤ N ≤ 200,000)이 주어진다. 둘째 줄부터 N개의 줄에 1번 나무의 좌표부터 차례대로 주어진다. 각각의 좌표는 200,000보다 작은 자연수 또는 0이다.

www.acmicpc.net

다음 포스팅에서는 펜윅트리를 가지고 1280번 나무 심기 문제를 풀어보도록 하겠다.

아래는 정답코드

더보기
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;

#define MOD (1000000007LL)
int _max = 200001;
int n;
ll t[200001];
ll fw[200002][2]; // x 좌표에서의 개수, 1~x 좌표까지의 합

void update(int x, int v) {
	while (x <= _max) {
		fw[x][0] += v;
		fw[x][1] += 1;
		x += (x & -x);
	}
}
pll query(int x) { // 좌표의 합과 개수를 반환
	pll ret = { 0,0 };
	while (x > 0) {
		ret.first += fw[x][0];
		ret.second += fw[x][1];
		x -= (x & -x);
	}
	return ret;

}


void solve() {
	cin >> n;
	for (int i = 1; i <= n; i++) {
		cin >> t[i];
		t[i] += 1; // 1 base  하기 위해.
	}
	ll ans = 1;
	update(t[1], t[1]);
	for (int i = 2; i <= n; i++) {
		pll ret1 = query(t[i]);
		pll ret2 = query(200001);
		ll ret = ret2.first - ret1.first - (ret2.second - ret1.second) * t[i];
		ret += t[i] * ret1.second - ret1.first;
		ret %= MOD;
		ans *= ret;
		ans %= MOD;
		update(t[i], t[i]);
	}
	cout << ans;
}


int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	int tc = 1;
	//cin >> tc;
	for (; tc--;) {
		solve();
	}
}
반응형