algorithm/SegmentTree

boj 10167 금광

와이에쓰 2022. 9. 2. 15:47

https://www.acmicpc.net/problem/10167

 

10167번: 금광

첫 줄에는 금광들의 개수 N (1 ≤ N ≤ 3,000)이 주어진다. 이어지는 N개의 줄 각각에는 금광의 좌표 (x, y)를 나타내는 음이 아닌 두 정수 x와 y(0 ≤ x, y ≤ 109), 그리고 금광을 개발하면 얻게 되는 이

www.acmicpc.net

이 문제를 풀기 전에, [boj 16993 연속합과 쿼리] 문제를 먼저 푸시는 것을 추천드립니다. [풀이]

 

금광을 풀기 위해 금광 세그가 등장하여 많이 알려진 것을 보면 얼마나 어려운 문제인지 예상이 될 것 같습니다.

금광세그 + 좌표압축 + 스위핑으로 많은 테크닉이 이 하나의 문제를 풀기 위해 필요합니다.

 

전체 흐름은

  1. 좌표압축으로 좌표 범위를 줄인 후에
  2. y 좌표를 기준으로 x 좌표들을 모으고
  3. y 좌표 범위를 이동시키며, 연속합의 최대를 구한다.

이렇게 되겠습니다.

 

좌표 압축

좌표를 인덱스로 하여 트리를 구성합니다.

좌표압축 설명은 [여기]에 포스팅되어 있습니다.

x, y 좌표 모두 압축 진행해주시면 됩니다.

y 좌표를 기준으로 x 좌표를 모을 예정이므로 다시 재정렬할 필요는 없습니다.

 

스위핑

먼저, y 좌표를 인덱스로 같은 y 좌표를 가지는 x 좌표를 벡터에 모읍니다.

yList[1] = {1, 2, 3} // y = 1인 x 좌표가 1, 2, 3

그 다음 y 좌표 범위를 이동하면서 최적의 답을 구합니다.

y 좌표의 범위는 1 ~ yCnt, 2 ~ yCnt, ..., y ~ yCnt에 존재하는 x 좌표들 대상입니다.

범위가 고정되었으니 이제 x 좌표들로 연속합의 최대를 구합니다.

 

금광세그

여기부터는 이제 연속합과 쿼리 문제와 동일합니다.

연속합의 최대를 구하기 위해 세그먼트 트리를 이용합니다.\

 

트리에는 아래 값들을 저장하게 됩니다.

  • l: leftNode의 최대
    • max(leftNode.l, leftNode.sum + rightNode.l)
  • r: rightNode의 최대
    • max(rightNode.r, rightNode.sum + leftNode.r)
  • max: 현재 노드의 최대
    • max(leftNode.max, rightNode.max, leftNode.r + rightNode.l)
  • sum: 현재 노드의 전체 합
    • leftNode.sum + rightNode.sum

y 좌표의 범위가 바뀔 때마다 들어오는 x 좌표 값들이 다르므로 트리는 반드시 초기화 해야합니다.

 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include <iostream>
#include <algorithm>
#include <vector>
#include <cstring>
#define MAX 3001
using namespace std;
typedef long long ll;
typedef pair<intint> pii;
typedef pair<int, ll> pil;
 
typedef struct Point {
    int x, y;
    ll w;
    int idx;
}Point;
 
typedef struct Node {
    ll l, r, max, sum;
}Node;
 
Point list[MAX];
Node tree[MAX * 4];
vector<pil> yList[MAX];
int N;
 
void update(int node, int s, int e, int idx, ll diff) {
    if (idx < s || idx > e) return;
    if (s == e) {
        tree[node].l += diff;
        tree[node].r += diff;
        tree[node].max += diff;
        tree[node].sum += diff;
        return;
    }
 
    int m = (s + e) / 2;
    update(node * 2, s, m, idx, diff);
    update(node * 2 + 1, m + 1, e, idx, diff);
    tree[node] = { max(tree[node * 2].l, tree[node * 2].sum + tree[node * 2 + 1].l), max(tree[node * 2 + 1].r, tree[node * 2 + 1].sum + tree[node * 2].r), max(max(tree[node * 2].max, tree[node * 2 + 1].max), tree[node * 2].r + tree[node * 2 + 1].l), tree[node * 2].sum + tree[node * 2 + 1].sum };
}
 
Node query(int node, int s, int e, int l, int r) {
    if (s > r || l > e) return { (int)-1e9, (int)-1e9, (int)-1e90 };
    if (l <= s && e <= r) return tree[node];
 
    int m = (s + e) / 2;
    Node leftNode = query(node * 2, s, m, l, r);
    Node rightNode = query(node * 2 + 1, m + 1, e, l, r);
    return { max(leftNode.l, leftNode.sum + rightNode.l), max(rightNode.r, rightNode.sum + leftNode.r), max(max(leftNode.max, rightNode.max), leftNode.r + rightNode.l), leftNode.sum + rightNode.sum };
}
 
void func() {
    sort(list + 1, list + 1 + N, [](Point a, Point b) {
        return a.x < b.x;
    });
 
    int pre = 1e9 + 1;
    int xCnt = 0;
    for (int i = 1; i <= N; i++) {
        if (pre != list[i].x) xCnt++;
        pre = list[i].x;
        list[i].x = xCnt;
    }
 
    sort(list + 1, list + 1 + N, [](Point a, Point b) {
        return a.y < b.y;
    });
 
    pre = 1e9 + 1;
    int yCnt = 0;
    for (int i = 1; i <= N; i++) {
        if (pre != list[i].y) yCnt++;
        pre = list[i].y;
        list[i].y = yCnt;
    }
 
    for (int i = 1; i <= N; i++) {
        yList[list[i].y].push_back({ list[i].x, list[i].w });
    }
 
    ll ans = -1e9 - 1;
    for (int i = 1; i <= yCnt; i++) {
        memset(tree, 0sizeof(tree));
        for (int j = i; j <= yCnt; j++) {
            for (int k = 0; k < yList[j].size(); k++) {
                update(11, N, yList[j][k].first, yList[j][k].second);
            }
            ans = max(ans, query(11, N, 1, N).max);
        }
    }
 
    cout << ans << '\n';
}
 
void input() {
    cin >> N;
    for (int i = 1; i <= N; i++) {
        cin >> list[i].x >> list[i].y >> list[i].w;
        list[i].idx = i;
    }
}
 
int main() {
    cin.tie(NULL); cout.tie(NULL);
    ios::sync_with_stdio(false);
 
    input();
    func();
 
    return 0;
}
cs