ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • APIO 2021 후기+C풀이
    APIO 2021. 5. 27. 21:06

    마지막 APIO가 끝났습니다~~. 생각보다 못해서 별 기대 안하고 있었지만 상을 받게 되어서 놀랐는데, 점수는 9/37/100으로 다른 참가자들이 많이 받은 B문제의 81점도 못받아서 아쉽게 끝난 감이 있습니다. 또 A의 20점을 '끝나기 직전에 긁어야지~'라는 안일한 생각으로 놓친 것도 반성해야 한다고 생각합니다.

     

    대신 C에서 100점을 획득했는데, 정해로 보이는 Small-to-Large를 쓰지 않고 완전히 다른 방법으로 AC를 맞았습니다. 개인적으로 이 방법이 더 직관적이고 코딩도 간단하다고 생각하여 풀이를 소개해보고자 합니다.

     

    문제 설명

     

    N(N<=10^5) 크기의 간선에 가중치가 부여된 트리가 주어진다. 이때 모든 k(0<=k<N)에 대해, 모든 정점에 대해 최대 k개의 간선만이 살아있도록 간선들을 지울 때 지우는 간선들의 가중치를 최소화하여라.

     

    Subtask 1 - 5 pts

     

    성게 모양의 트리이다. 모든 간선의 가중치를 정렬한 뒤 적절히 선택해주면 된다.

     

    Subtask 2 - 7 pts

     

    직선 모양의 트리이다. k>2일때는 답이 당연히 0일테니 k=0과 k=1인 case만 고려하면 된다. k=0일때의 답은 모든 간선의 가중치의 합일 것이고, k=1일때의 답은 굉장히 trivial한 DP를 통해 해결할 수 있다. DP[i][j]를 1~i번 까지의 정점에 조건을 만족시키도록 최적의 방법으로 간선을 지우며 j=0이면 i와 i+1을 잇는 간선이 살아있고, j=1이면 그 간선이 지워졌을 때의 최솟값이라 정의하자. DP[i][0]=DP[i-1][1]이며, DP[i][1]=min(DP[i-1][0], DP[i-1][1])+(i와 i+1을 잇는 간선의 가중치)이다.

     

    Subtask 3&4 - 38 pts

     

    Subtask 2의 풀이를 트리로 옮기면 된다. 각각의 k에 대해 일일이 계산하면 된다.

     

    DP[i][j]의 기존 정의를 (1~i번 정점)을 (i번 정점의 서브트리)로, (i와 i+1을 잇는 간선)을 (i와 그의 부모노드를 잇는 간선)로 변경한다. 또 DP의 상태전이를 모든 일반적인 k로 확장해야 한다.

     

    정점 v와 그 자식노드(c1, c2, ..., cm) 사이의 간선을 모두 지우고, 거기서 일부 간선을 다시 추가하는 상황을 생각해보자. 모두 지운 값은 dp[c1][1]+...+dp[cm][1]이고, ci와 연결된 간선을 다시 살리는 비용은 dp[ci][0]-dp[ci][1]이다. 이 값이 0보다 크면 오히려 해당 간선을 살리는 것이 손해라는 것이므로 이러한 값들을 제외하고 생각하자. dp[v][0]의 경우에는 dp[ci][0]-dp[ci][1]들 중 가장 작은 k-1개를, dp[v][1]은 가장 작은 k개를 dp[c1][1]+...+dp[cm][1]에 더한 값이 된다. 물론 dp[v][1]의 경우에는 v와 그 부모를 잇는 간선의 가중치를 따로 더해줘야 한다.

     

    각 정점에 대해 dp[ci][0]-dp[ci][1]들을 정렬하므로 각 step의 시간복잡도는 O(NlgN)이고, 총 시간복잡도는 O(N^2lgN)이다.

     

    Full Solution - 100 pts

     

    생각해보면, deg(v)<=k를 만족하는 v들은 조건에서 자유롭다는 것을 알 수 있다. 트리에서 deg(1)+deg(2)+...+deg(n)=2n-2 이므로, Subtask 3&4에서 설명한 연산을 deg(v)>k인 정점에서만 해줄 수 있다면 전체 문제를 O(NlgN)에 해결할 수 있다는 사실을 이용하여 문제를 풀어보도록 하겠다.

     

    필요없는 정점들을 생략하기 위해 deg(v)>k인 정점들만을 활용하여 트리를 압축하도록 하겠다. 트리압축은 다음과 같은 과정을 통해 진행된다.

     

    1) 고려할 정점 M개를 euler tour 순서를 기준으로 정렬한다. (해당 배열을 arr이라 하자.)

    2) lca(arr[i], arr[i+1]) (i<M)을 arr의 뒤에 추가한다.

    3) 다시 전체 정점을 euler tour순으로 정렬하고, unique함수 등을 이용해 동일한 정점들을 제거한다.

    4) stack을 사용하여 arr에 남은 정점들 사이에 간선을 적절히 이어준다. (euler tour의 과정을 생각하면서 하면 쉽다.)

     

    위 과정을 걸쳐서 원하는 정점이 M개라면, 2M-1개의 유의미한 정점들만을 활용하여 트리 상에서의 거리관계를 유지하면서 트리를 압축할 수 있다.

     

    간선의 가중치는 parent와 인접한 가중치로 설정한다. 즉, 1-2-3-4 의 chain을 1-4로 압축하였을때의 가중치는 기존 트리에서 1-2의 가중치로 한다.

     

    이제 dp를 해주어야 하는데, 위의 Subtask들과 달리 고려해야 할 사항이 몇가지 더 있다.

     

    1) 압축된 트리에서는 부모-자식 관계지만 실제 트리에서는 부모-자식 관계가 아닌 경우

    실제 트리에서는 부모-자손 관계일 것이다. 여기서 달라지는 것은 부모입장에서 전이할 때 보이는 dp[ci][0]과 dp[ci][1]의 값이다. dp의 정의를 따르면 위 두개의 값은 실제 트리에서 ci와 그의 부모를 잇는 간선의 삭제 유무이지, 압축된 트리에서 ci와 그의 부모를 잇는 간선의 삭제 유무가 아니다. 따라서 이 경우, dp[ci][0]의 값을 실제로는 min(dp[ci][0], dp[ci][1]), dp[ci][1]의 값을 min(dp[ci][0], dp[ci][1])+(압축된 트리에서 v와 ci를 잇는 간선의 가중치)로 보아야 한다.

     

    2) 압축된 트리에서는 사라졌지만, dp 전이에 필요한 간선들

    각 정점에서 자신의 자식 정점들의 dp[ci][0]-dp[ci][1]값을 관리해야 하는데, 압축된 트리에서는 자신의 원래 자손들의 정보를 전부 저장할 수 없다. 하지만 이 경우 압축된 트리에서 보이지 않는 정점들은 실제로 그 subtree의 모든 정점들의 deg 값이 전부 k이하라는 사실을 다시 생각해보자. 즉, dp[ci][0]=dp[ci][1]=0인 것이다.

     

    1번과 2번 예외 케이스를 동시에 전부 해결하는 방법이 있다. k의 값을 0부터 N-1까지 늘려가면서 어떤 정점 v가 deg(v)<=k가 되어 의미없어지는 경우, 이를 'v를 무효화한다'라고 하겠다. 모든 정점은 최대 한번 무효화된다. 1번과 2번 예외케이스 모두 어떤 정점 v가 무효화되었을 때 생기는 경우이며, 우리가 실제로 필요한 dp[ci][0]-dp[ci][1] 값은 어떤 자식정점 ci가 무효화되었을 때 그 이후 항상 일정하다는 사실을 관찰하자. 심지어 그때의 값은 기존 트리에서 v와 ci를 잇는 간선의 가중치*(-1)이다.

     

    각각의 정점마다 자식 정점의 개수만큼의 크기를 갖는 Segment Tree를 구축하자. 해당 Segment Tree는 숫자의 범위가 [a, b]안에 있으면서, 최대 k개의 원소만을 사용한 sum의 최솟값을 계산할 수 있어야 한다. 정점이 무효화 되었을 때 그 부모정점의 Segment Tree에 간선의 가중치*(-1) 값을 insert한다. 전이 단계에서는 temp라는 값에 압축된 트리에서 부모-자식 관계이면서, 실제 트리에서도 부모-자식 관계인 쌍들에 대해서는 dp[ci][1]을 더해주고,  압축된 트리에서 부모-자식 관계이지만 트리에서는 그렇지 않은 쌍에 대해서는 min(dp[ci][0], dp[ci][1])를 전부 더해준다. 첫번째 경우에서 나온 dp[ci][0]-dp[ci][1]들과, v에 있는 Segment Tree에서의 쿼리들을 적절히 배합하면 시간복잡도를 벗어나지 않고 문제를 해결할 수 있다.

     

    코드는 다음과 같습니다.

     

    #include <bits/stdc++.h>
    #define mp make_pair
    #define eb emplace_back
    #define F first
    #define S second
    #define all(x) x.begin(), x.end()
    #define svec(x) sort(all(x))
    #define press(x) x.erase(unique(all(x)), x.end());
    using namespace std;
    typedef long long LL;
    typedef pair<int, int> pii;
    typedef pair<int, LL> pil;
    typedef pair<LL, int> pli;
    typedef pair<LL, LL> pll;
    const int INF=1e9;
    const LL LLINF=1e18;
    
    int n, par[100010], num[100010], sp[100010][30], d[100010], re;
    pii eul[100010];
    vector<pil> link[100010];
    LL dp[100010][2], pedge[100010], sum2[100010];
    bool ch[100010];
    
    int lca(int a, int b){
        if(d[a]<d[b])swap(a, b);
        for(int i=20; i>=0; i--){
            if(d[a]-(1<<i)>=d[b])a=sp[a][i];
        }
        if(a==b)return a;
        for(int i=20; i>=0; i--){
            if(sp[a][i]!=sp[b][i])a=sp[a][i], b=sp[b][i];
        }
        return sp[a][0];
    }
    
    pli operator+(pli a, pli b){return mp(a.F+b.F, a.S+b.S);}
    struct DYNAMIC_MINCOWSKI{
        vector<pli> vc, seg;
        void init(){
            seg.resize(vc.size()*4);
            svec(vc);
            for(int i=0; i<vc.size(); i++)num[vc[i].S]=i;
        }
        void upd(int point, int s, int e, int num){
            if(s==e){
                seg[point]=mp(vc[s].F, 1);
                return;
            }
            if(num<=(s+e)/2)upd(point*2, s, (s+e)/2, num);
            else upd(point*2+1, (s+e)/2+1, e, num);
            seg[point]=seg[point*2]+seg[point*2+1];
        }
        void upd(int num){upd(1, 0, vc.size()-1, num);}
        pli query(int point, int s, int e, int a, int b, int cut){
            if(e<a||s>b||cut<=0)return mp(0ll, 0);
            if(a<=s&&e<=b&&seg[point].S<=cut)return seg[point];
            pli tmp=query(point*2, s, (s+e)/2, a, b, cut);
            pli tmp2=query(point*2+1, (s+e)/2+1, e, a, b, cut-tmp.S);
            return tmp+tmp2;
        }
        pli query(int a, int b, int cut){return query(1, 0, vc.size()-1, a, b, cut);}
        int lb(LL c){
            return lower_bound(all(vc), mp(c, INF))-vc.begin();
        }
    }mink[100010];
    
    vector<int> deg[100010];
    
    void pdfs(int num, int p){
        d[num]=d[p]+1;
        par[num]=sp[num][0]=p;
        eul[num].F=++re;
        for(auto i:link[num]){
            if(i.F==p)continue;
            mink[num].vc.eb(-i.S, i.F);
            pedge[i.F]=i.S;
            pdfs(i.F, num);
        }
        eul[num].S=re;
        mink[num].init();
    }
    
    vector<int> prs[100010];
    
    void dfs(int now, int k){
        LL sum=sum2[now];
        vector<LL> vc;
        for(auto i:prs[now]){
            dfs(i, k);
            if(d[i]-d[now]==1&&ch[i]==false){
                sum+=dp[i][1];
                if(dp[i][0]<dp[i][1])vc.eb(dp[i][0]-dp[i][1]);
            }
            else sum+=min(dp[i][0], dp[i][1]);
        }
        svec(vc);
        dp[now][0]=sum;
        dp[now][1]=sum+pedge[now];
        pli tmp=mp(0ll, 0);
        int cut=0;
        for(auto i:vc){
            int ncut=mink[now].lb(i);
            tmp=tmp+mink[now].query(cut, ncut-1, k-1-tmp.S);
            cut=ncut;
            if(tmp.S>=k-1)break;
            tmp.S++;
            tmp.F+=i;
        }
        tmp=tmp+mink[now].query(cut, INF, k-1-tmp.S);
        dp[now][0]+=tmp.F;
    
        tmp=mp(0ll, 0);
        cut=0;
        for(auto i:vc){
            int ncut=mink[now].lb(i);
            tmp=tmp+mink[now].query(cut, ncut-1, k-tmp.S);
            cut=ncut;
            if(tmp.S>=k)break;
            tmp.S++;
            tmp.F+=i;
        }
        tmp=tmp+mink[now].query(cut, INF, k-tmp.S);
        dp[now][1]+=tmp.F;
    }
    
    set<int> s;
    bool cmp(int a, int b){return eul[a]<eul[b];}
    
    vector<LL> minimum_closure_costs(int N, vector<int> U, vector<int> V, vector<int> W){
        n=N;
        vector<LL> ret(n);
        for(int i=0; i<n-1; i++){
            link[U[i]+1].eb(V[i]+1, W[i]);
            link[V[i]+1].eb(U[i]+1, W[i]);
        }
    
        for(int i=1; i<=n; i++){
            deg[link[i].size()].eb(i);
            s.insert(i);
        }
        pdfs(1, 0);
        for(int j=1; j<=20; j++){
            for(int i=1; i<=n; i++)sp[i][j]=sp[sp[i][j-1]][j-1];
        }
    
        for(int i=0; i<n; i++){
            for(auto j:deg[i]){
                s.erase(j);
                ch[j]=true;
                if(!par[j])continue;
                sum2[par[j]]+=pedge[j];
                mink[par[j]].upd(num[j]);
            }
    
            vector<int> vc, stk;
            for(auto j:s)vc.eb(j);
            for(int j=1; j<s.size(); j++)vc.eb(lca(vc[j-1], vc[j]));
            vc.eb(1);
            sort(all(vc), cmp); press(vc);
    
            stk.eb(vc[0]);
            for(int j=1; j<vc.size(); j++){
                while(1){
                    if(eul[stk.back()].S>=eul[vc[j]].F)break;
                    stk.pop_back();
                }
                prs[stk.back()].eb(vc[j]);
                stk.eb(vc[j]);
            }
            dfs(1, i);
            ret[i]=dp[1][1];
            for(auto j:vc)prs[j].clear();
        }
    
        return ret;
    }
    

     

     

Designed by Tistory.