Xor sum - Atcoder
문제 링크 : atcoder.jp/contests/abc050/tasks/arc066_b
D - Xor Sum
AtCoder is a programming contest site for anyone from beginners to experts. We hold weekly programming contests online.
atcoder.jp
관찰)
A + B = A ^ B + 2 ( A & B )
A | B + A & B = A + B
-> A | B = (u + v) / 2, A & B = (u - v) / 2 이다.
-> A | B + A & B <= N인 A, B 쌍을 찾으면 되는 문제이다.
-> x + y <= N, x OR y = x인 x, y 쌍을 찾으면 되는 문제이다.
dp 설계)
비트 dp를 할건데, 맨 아랫자리부터 할것이다.
변수 세개가 있다.
digit - 현재 몇번째 자리를 보고 있는지
flowBit - 전 자릿수에서 1이 넘어왔는지 여부
under - 현재 자리~끝 자리까지만 봤을 때 N을 초과했는지의 여부
각 digit마다 가능한 경우로 (x,y) = (0,0) (1,0) (1,1)이 있다. (0,1)은 x|y!=x라 안된다.
이 세가지 경우에 대해
(현재 자릿수로 넘어온 flowbit) + x + y로 현재 자리의 bit을 계산해주고
다음 자릿수로 넘어갈 flowbit도 계산해준다.
현재 자리의 bit이 N의 현재 자리 bit과 같으면 under 변수는 그대로 간다. (원래 넘었던 상태면 계속 넘은 상태, 원래 안남었던 상태면 계속 안 넘었던 상태)
현재 자리의 bit이 0인데, N의 현재자리 bit이 1이면 under은 true, 반대 상황은 under은 false가 될 것이다.
이걸 가장 높은 62번째 자리까지 반복하면, N은 1e18이 한계이기 때문에
무조건 under = true, 넘어온 flowbit이 0인 경우만 허용이 될 것이다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
|
#include <bits/stdc++.h>
using namespace std;
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("avx,avx2")
#define IOS ios::sync_with_stdio(false);cin.tie(0)
#define all(x) x.begin(), x.end()
#define ff first
#define ss second
#define LLINF 0x3f3f3f3f3f3f3f3f
#define INF 0x3f3f3f3f
#define uniq(x) sort(all(x)); x.resize(unique(all(x))-x.begin());
#define sz(x) (int)x.size()
#define pw(x) (1LL<<x)
using pii = pair<int, int>;
using ll = long long;
const ll MOD = 1e9 + 7;
const long double PI = acos(-1.0);
void solve() {
ll N; cin >> N;
vector<vector<vector<ll>>> dp(64, vector<vector<ll>>(2, vector<ll>(2, -1)));
function<ll(int,int,int)> dfs = [&](int digit, int flowBit, int under) {
ll& ret = dp[digit][flowBit][under];
if(ret != -1)
return ret;
ret = 0;
int curBit = (N >> digit) & 1;
if(digit == 62) {
if(flowBit == 0 && under == true)
return ret = 1LL;
else
return ret = 0LL;
}
for(int bit1 = 0 ; bit1 <= 1 ; bit1 ++) {
for(int bit2 = 0 ; bit2 <= 1 ; bit2 ++) {
if(bit2 > bit1)
continue;
int nxtBit = (bit1 + bit2 + flowBit) >= 2;
int calcedCurBit = (bit1 + bit2 + flowBit) & 1;
int newUnder = under;
if(calcedCurBit < curBit)
newUnder = 1;
else if(calcedCurBit > curBit)
newUnder = 0;
ret += dfs(digit + 1, nxtBit, newUnder);
ret %= MOD;
}
}
return ret;
};
ll ans = dfs(0, 0, 1);
cout << ans << "\n";
}
int main() {
IOS;
int t; t = 1;
while(t--)
solve();
}
|
cs |