https://www.acmicpc.net/problem/14287
[필자 사고]
이 문제는 세그먼트 트리 자료구조를 이용해서 풀어야 되는 문제이다.
다만 이 문제만의 재미난 점은 세그먼트 index를 설정해야 되는데 그게 dfs탐색을 이용한 오일러 알고리즘을 통해
index들을 새로 정의 해야 된다.
또한 새로 정의된 index를 바탕으로 월급또한 같이 재정의 해야된다는 점이다.
DFS알고리즘을 잘 수행했따면 Propagation 으로 느린전파 알고리즘을 이용하여 Update의 범위 쿼리를 갱신해주면 문제를 풀 수 있게 된다.
여기서 만약 1-based로 시작하면 index[a].first 가 의미하는 거는 a의 실제 index자신을 의미하므로
a보다 부하를 원하면 index[a].first +1 을 해야된다는 점이다.
이 문제에서 0-based로 했기 때문에 아래와 같이 작성했다.
[코드 해설]
2. 입력 처리
프로그램은 다음과 같은 입력을 처리합니다:
- n (노드 개수)와 m (쿼리 개수)을 입력받습니다.
- 루트 노드의 부모 정보를 입력받고, 이후 각 노드의 부모 정보를 읽어 자식 노드 목록을 생성합니다.
입력 예시
복사편집
5 3 0 1 1 2
- 5개의 노드가 존재하며, 루트 노드는 0입니다.
- 각 노드의 부모 정보는 다음과 같이 주어집니다:
- 1번 노드의 부모는 0번 노드
- 2번 노드의 부모는 0번 노드
- 3번 노드의 부모는 1번 노드
- 4번 노드의 부모는 1번 노드
3. DFS 함수
dfs(int x)
- 노드 x를 기준으로 DFS를 수행하며, 트리의 각 노드에 대해 서브트리 범위를 계산합니다.
- 작동 방식:
- 현재 노드의 시작 인덱스를 기록 (dat[x].first = cnt).
- 자식 노드를 순회하며 DFS를 재귀 호출.
- 모든 자식 노드를 탐색한 후, 현재 노드의 끝 인덱스를 기록 (dat[x].second = cnt - 1).
dat[x] 의미
- dat[x].first: 노드 x의 서브트리 시작 인덱스.
- dat[x].second: 노드 x의 서브트리 끝 인덱스.
4. 세그먼트 트리 연산
세그먼트 트리는 노드 값 업데이트와 범위 합 쿼리를 빠르게 수행하기 위해 사용됩니다.
4.1 update(int x, int s, int e, int idx, long long v)
- 세그먼트 트리에서 특정 위치 idx에 값을 v만큼 추가합니다.
- 매개변수:
- x: 현재 세그먼트 트리 노드 번호.
- s, e: 현재 노드가 담당하는 범위.
- idx: 값을 추가할 인덱스.
- v: 추가할 값.
- 작동 방식:
- 업데이트 범위에 포함되지 않으면 그대로 반환.
- 리프 노드에 도달하면 값을 추가.
- 범위를 나눠 좌우 자식 노드를 재귀적으로 업데이트.
4.2 query(int x, int s, int e, int l, int r)
- 특정 범위 [l, r]의 값을 합산합니다.
- 매개변수:
- x: 현재 세그먼트 트리 노드 번호.
- s, e: 현재 노드가 담당하는 범위.
- l, r: 합을 구할 범위.
- 작동 방식:
- 범위가 겹치지 않으면 0 반환.
- 범위가 완전히 포함되면 해당 노드 값을 반환.
- 좌우 자식 노드로 나눠 재귀 호출.
5. 주요 로직 (메인 함수)
초기화
- dfs(0) 호출: 루트 노드부터 서브트리 범위를 계산.
- dat 배열에 각 노드의 서브트리 범위 저장.
쿼리 처리
- 총 m개의 쿼리를 처리하며, 두 가지 명령을 지원합니다:
- 명령 1: 1 x v
- 노드 x의 서브트리에 값 v를 추가.
- update(1, 0, n-1, dat[x].first, v) 호출.
- 명령 2: 2 x
- 노드 x의 서브트리에 저장된 값의 합을 계산.
- query(1, 0, n-1, dat[x].first, dat[x].second) 호출.
- 명령 1: 1 x v
[소스 코드]
#include <iostream>
#include <vector>
using namespace std;
int n, m, cnt;
vector<int> v[100001];
pair<int, int> dat[100001];
long long seg[100001 * 4];
void dfs(int x) {
dat[x].first = cnt++;
for (int i = 0; i < v[x].size(); i++) {
int y = v[x][i];
dfs(y);
}
dat[x].second = cnt - 1;
}
long long update(int x, int s, int e, int idx, long long v) {
if (idx < s || e < idx) return seg[x];
if (s == e) {
return seg[x] += v;
}
int mid = s + (e - s) / 2;
return seg[x] = update(x * 2, s, mid, idx, v) + update(x * 2 + 1, mid + 1, e, idx, v);
}
long long query(int x, int s, int e, int l, int r) {
if (r < s || e < l) return 0;
if (l <= s && e <= r) return seg[x];
int mid = s + (e - s) / 2;
return query(x * 2, s, mid, l, r) + query(x * 2 + 1, mid + 1, e, l, r);
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> n >> m;
int p;
cin >> p;
for (int i = 1; i < n; i++) {
cin >> p;
v[p - 1].push_back(i);
}
dfs(0);
for (int i = 0; i < m; i++) {
int cmd;
cin >> cmd;
if (cmd == 1) {
int x;
long long v;
cin >> x >> v;
x--;
update(1, 0, n - 1, dat[x].first, v);
}
else {
int x;
cin >> x;
x--;
cout << query(1, 0, n - 1, dat[x].first, dat[x].second) << '\n';
}
}
return 0;
}