PS

Xor sum - Atcoder

jyheo98 2021. 1. 23. 10:04

문제 링크 : atcoder.jp/contests/abc050/tasks/arc066_b

 

D - Xor Sum

AtCoder is a programming contest site for anyone from beginners to experts. We hold weekly programming contests online.

atcoder.jp

관찰)

 

A + B = A ^ B + 2 ( A & B )

A | B + A & B = A + B 

 

-> A | B = (u + v) / 2, A & B = (u - v) / 2 이다.

-> A | B + A & B <= N인 A, B 쌍을 찾으면 되는 문제이다.

-> x + y <= N, x OR y = x인 x, y 쌍을 찾으면 되는 문제이다.

 

dp 설계)

 

비트 dp를 할건데, 맨 아랫자리부터 할것이다.

변수 세개가 있다. 

digit - 현재 몇번째 자리를 보고 있는지

flowBit - 전 자릿수에서 1이 넘어왔는지 여부

under - 현재 자리~끝 자리까지만 봤을 때 N을 초과했는지의 여부

 

각 digit마다 가능한 경우로 (x,y) = (0,0) (1,0) (1,1)이 있다. (0,1)은 x|y!=x라 안된다.

이 세가지 경우에 대해

(현재 자릿수로 넘어온 flowbit) + x + y로 현재 자리의 bit을 계산해주고

다음 자릿수로 넘어갈 flowbit도 계산해준다.

 

현재 자리의 bit이 N의 현재 자리 bit과 같으면 under 변수는 그대로 간다. (원래 넘었던 상태면 계속 넘은 상태, 원래 안남었던 상태면 계속 안 넘었던 상태)

 

현재 자리의 bit이 0인데, N의 현재자리 bit이 1이면 under은 true, 반대 상황은 under은 false가 될 것이다.

 

이걸 가장 높은 62번째 자리까지 반복하면, N은 1e18이 한계이기 때문에

무조건 under = true, 넘어온 flowbit이 0인 경우만 허용이 될 것이다.

 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <bits/stdc++.h>
using namespace std;
 
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("avx,avx2")
 
#define IOS ios::sync_with_stdio(false);cin.tie(0)
#define all(x) x.begin(), x.end()
#define ff first
#define ss second
#define LLINF 0x3f3f3f3f3f3f3f3f
#define INF 0x3f3f3f3f
#define uniq(x) sort(all(x)); x.resize(unique(all(x))-x.begin());
#define sz(x) (int)x.size()
#define pw(x) (1LL<<x)
 
using pii = pair<intint>;
using ll = long long;
const ll MOD = 1e9 + 7;
const long double PI = acos(-1.0);
 
void solve() {
    ll N; cin >> N;
 
    vector<vector<vector<ll>>> dp(64vector<vector<ll>>(2vector<ll>(2-1)));
 
    function<ll(int,int,int)> dfs = [&](int digit, int flowBit, int under) {
        ll& ret = dp[digit][flowBit][under];
        if(ret != -1)
            return ret;
        ret = 0;
        int curBit = (N >> digit) & 1;
        if(digit == 62) {
            if(flowBit == 0 && under == true
                return ret = 1LL;
            else
                return ret = 0LL;
        }
        for(int bit1 = 0 ; bit1 <= 1 ; bit1 ++) {
            for(int bit2 = 0 ; bit2 <= 1 ; bit2 ++) {
                if(bit2 > bit1) 
                    continue;
                int nxtBit = (bit1 + bit2 + flowBit) >= 2;
                int calcedCurBit = (bit1 + bit2 + flowBit) & 1;
                int newUnder = under;
                if(calcedCurBit < curBit)
                    newUnder = 1;
                else if(calcedCurBit > curBit) 
                    newUnder = 0;
                ret += dfs(digit + 1, nxtBit, newUnder);
                ret %= MOD;
            }
        }
        return ret;
    };
 
    ll ans = dfs(001);
    cout << ans << "\n";
}
 
int main() {
    IOS;
    int t; t = 1;
    while(t--)
        solve();
}
cs