Lowest Common Ancestor

LCA

N(2 ≤ N ≤ 100,000)개의 정점으로 이루어진 트리가 주어진다. 트리의 각 정점은 1번부터 N번까지 번호가 매겨져 있으며, 루트는 1번이다. 두 노드의 쌍 M(1 ≤ M ≤ 100,000)개가 주어졌을 때, 두 노드의 가장 가까운 공통 조상이 몇 번인지 출력

LCA는 두 노드 간의 경로 길이 나 촌수 계산 등에 활용

Native 구현

dfs(int curr, int depth) {
    depths[curr] = depth;
    visited[curr] = 1;
    
    for (int next : adj.get(curr)) {
        
        if (visited[next] == 0) {
            dfs(next, depth + 1);
            parents[next] = curr;
            
        }
    }
}

depth와 parent 를 맞춰준다

int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());

if (depths[u] < depths[v]) {
    int t = u;
    u = v;
    v = t;
}
while (depths[u] != depths[v]) {
    u = parents[u];
}
while (u != v) {
    u = parents[u];
    v = parents[v];
}
System.out.println(u);

DP 구현

tree가 depth를 맞추는 것과 parent를 맞추는 과정을 빠르게 해보자

parents[u][k] 정점 u의 2^k 번째 부모 정의

parents[u][k] = parent[parent[u][k-1]][k-1]

depth diff가 3 이라면 이진 수로 011(2) 이고 parents[parents[u][0]][1] 으로 u를 맞춰준다.

int dist = depths[u] - depths[v];
int cnt = 0;
while (dist > 0) {
    if (dist % 2 == 1)
        u = parents[u][cnt];
    dist /= 2; // or dist >>= 1
    cnt++;
}

다른 방법으로는, 최대 MAX 로 부터 모두 체크해본다

for (int j = MAX; j >= 0; j--) {
    if (depths[v] <= depths[parents[u][j]]) {
        u = parents[u][j];
    }
}

depth가 같으면 이제 parent를 맞춘다

parents가 같다면

if (u != v) {
    for (int j = MAX; j >= 0; j--) {
        if (parents[u][j] != parents[v][j]) {
            u = parents[u][j];
            v = parents[v][j];
        }
    }
    u = parents[u][0];
}

구간 트리로 구현

  1. tree path를 만든다 dfs visited count 로 경로를 만들고 back할 때도 추가해야한다. 1->2->3->2->4 ….

  2. node가 표현된 tree path에서 몇번째 node 인가 늘어나는 path 크기를 담으면 된다

  3. query로 얻은 최소값은 (트리에서의 몇번째 node) 실제 어떤 노드인가? 1번에서의 visited count 에 해당하는 node를 저장하면됨

전위 순회를 통해 노드에 새로운 번호를 부여하고 구간트리에 이용할 1차원 배열을 만듦

void traverse(int here, int d, vector<int> &trip)
{
    no2serial[here] = nextSerial;
    serial2no[nextSerial] = here;
    nextSerial++;
    visitied[here] = true;
 
    depth[here] = d;
    locInTrip[here] = trip.size();
    trip.push_back(no2serial[here]);
 
    for(int i=0; i<graph[here].size(); i++){
        if(!visitied[graph[here][i]]){
          traverse(graph[here][i], d+1, trip);
          trip.push_back(no2serial[here]);
        }
    }
}
 
int LCA(int u, int v)
{
    int lu = locInTrip[u];
    int lv = locInTrip[v];
 
    if(lu > lv) swap(lu, lv);
    return serial2no[query(lu,lv,1,0,2*N-2)];
}