---
title: 'Project SEKAI CTF 2024: PPC - Nokotan'
date: 2024-08-30
tags: ['ctf', 'ppc']
---


## Task

> Time limit is 2 seconds for this challenge.
>
> [`https://ppc.chals.sekai.team/`](https://ppc.chals.sekai.team/)
>
> [`nokotan.pdf`](https://ctf.sekai.team/files/9da6f6acf52390532b58a9ce2ff074e7/nokotan.pdf?token=eyJ1c2VyX2lkIjozODM0LCJ0ZWFtX2lkIjo2OTYsImZpbGVfaWQiOjM0fQ.ZswikA.OQv0nG2XFvbPzEgaEuqsawWzRbU)

- `Author: null_awe`
- `Points: 115`
- `Solves: 55 / 1230 (4.471%)`


## Writeup

The goal of this challenge is to write a program that determines the number of distinct possible values (the sum of all node labels) of a complete binary tree with `n` nodes, where each leaf is labeled either 0 or 1 and every other node is labeled with the XOR of its children (or with the label of its only child, if it has just one).

We are first given the number of test cases `t`, followed by `t` lines, each containing the number of nodes in the tree `n`, where `n` is in the range `[1, 5e5]`.


Therefore, our script will look something like this:

```python
def solve(n: int) -> int:
    # todo: solve the problem
    pass


t = int(input())
for _ in range(t):
    n = int(input())
    print(solve(n))
```


A complete binary tree with `n` nodes has `n - n // 2` leaves (roughly half of the nodes), and each leaf has two possible labels, so simply iterating over every possible labeling takes exponential time.

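
A brute-force reference is still handy for checking faster versions on small `n`. The sketch below assumes the usual array layout of a complete binary tree (node `i` has children `2 * i + 1` and `2 * i + 2`); `brute_force_values` is a hypothetical helper name, not part of the solution:

```python
from itertools import product


def brute_force_values(n: int) -> set[int]:
    """All possible label sums of an n-node complete binary tree, by enumeration.

    Assumes heap indexing: node i has children 2 * i + 1 and 2 * i + 2,
    so nodes 0 .. n // 2 - 1 are internal and the other n - n // 2 are leaves.
    """
    internal = n // 2
    values = set()
    for leaf_labels in product((0, 1), repeat=n - internal):
        label = [0] * internal + list(leaf_labels)
        # fill internal nodes bottom-up: each is the XOR of its one or two children
        for i in range(internal - 1, -1, -1):
            label[i] = label[2 * i + 1]
            if 2 * i + 2 < n:
                label[i] ^= label[2 * i + 2]
        values.add(sum(label))
    return values


for n in range(1, 9):
    print(n, len(brute_force_values(n)))
```

For example, `brute_force_values(4)` returns `{0, 2, 3}`. The outer loop runs `2 ** (n - n // 2)` times, so this is only practical for small `n`.
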

Instead, we can compute the possible values of the two subtrees rooted at the root node's children and then determine the number of unique sums of their values.

First, we need to determine how many nodes are in each of the two subtrees. To do this, we can first ignore the last unfilled row, which leaves a perfect binary tree, and then determine how the nodes in that last row are distributed between the two subtrees.


```python
def divide(n: int) -> tuple[int, int]:
    p = 1
    # a perfect binary tree has 2 ** k - 1 nodes where k is its height + 1
    while p * 2 - 1 <= n:
        p *= 2
    # now the number of nodes excluding the last unfilled row = p - 1
    remaining_nodes = n - (p - 1)
    # for each subtree, the number of nodes excluding the last unfilled row is p // 2 - 1
    perfect_subtree_nodes = p // 2 - 1
    # the left subtree can have at most p // 2 of the remaining nodes
    left_subtree = perfect_subtree_nodes + min(remaining_nodes, p // 2)
    # the right subtree does not get any of the remaining nodes until the left subtree is full
    right_subtree = perfect_subtree_nodes + max(0, remaining_nodes - p // 2)
    return left_subtree, right_subtree
```

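
One way to sanity-check `divide` is to compare it against subtree sizes counted directly in the heap layout, where the root is node 0 and its children are nodes 1 and 2. This is a sketch; `subtree_size` is a hypothetical helper, and `divide` is copied from above with its comments trimmed so the snippet runs on its own:

```python
def divide(n: int) -> tuple[int, int]:
    """Sizes of the root's left and right subtrees (same logic as above)."""
    p = 1
    while p * 2 - 1 <= n:
        p *= 2
    remaining_nodes = n - (p - 1)
    perfect_subtree_nodes = p // 2 - 1
    left_subtree = perfect_subtree_nodes + min(remaining_nodes, p // 2)
    right_subtree = perfect_subtree_nodes + max(0, remaining_nodes - p // 2)
    return left_subtree, right_subtree


def subtree_size(i: int, n: int) -> int:
    """Hypothetical helper: size of the subtree rooted at heap index i in an n-node tree."""
    if i >= n:
        return 0
    return 1 + subtree_size(2 * i + 1, n) + subtree_size(2 * i + 2, n)


# in heap order the root is node 0 and its children are nodes 1 and 2
for n in range(1, 1025):
    assert divide(n) == (subtree_size(1, n), subtree_size(2, n)), n
print(divide(12))  # (7, 4)
```

For instance, `divide(12)` returns `(7, 4)`: the left subtree is a perfect 7-node tree and the right subtree holds the remaining 4 nodes.
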

Now we can implement `solve` as follows:

```python
import itertools


def solve(n: int) -> int:
    return len(possible_values(n))


def possible_values(n: int) -> set[int]:
    if n == 0:
        return {0}
    if n == 1:
        return {0, 1}
    a, b = divide(n)
    a_values, b_values = possible_values(a), possible_values(b)
    return unique_sums(a_values, b_values)


def unique_sums(a: set[int], b: set[int]) -> set[int]:
    return {x + y for x, y in itertools.product(a, b)}
```

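
However, this ignores the value of the root node (and, recursively, of every other internal node). For example, for `n = 3` this version reports the sums `{0, 1, 2}`, but the root of a 3-node tree is labeled with the XOR of its two leaves, so the only possible sums are 0 and 2. Since the value of the root node is determined by the values of its children, we need to keep track of the possible values where the root node is 1 and where the root node is 0.
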

```python
def solve(n: int) -> int:
    return len(possible_values_0(n) | possible_values_1(n))


def possible_values_0(n: int) -> set[int]:
    # n == 1 is a leaf labeled 0; n == 0 is an empty subtree, which acts
    # like a child labeled 0 that adds nothing to the sum
    if n <= 1:
        return {0}
    a, b = divide(n)
    a_values_0, b_values_0 = possible_values_0(a), possible_values_0(b)
    a_values_1, b_values_1 = possible_values_1(a), possible_values_1(b)
    # the root is 0 when both subtree roots are 0 or both are 1
    return unique_sums(a_values_0, b_values_0) | unique_sums(a_values_1, b_values_1)


def possible_values_1(n: int) -> set[int]:
    if n == 0:
        return set()
    if n == 1:
        return {1}
    a, b = divide(n)
    a_values_0, b_values_0 = possible_values_0(a), possible_values_0(b)
    a_values_1, b_values_1 = possible_values_1(a), possible_values_1(b)
    # the root is 1 when exactly one of the subtree roots is 1
    sums = unique_sums(a_values_0, b_values_1) | unique_sums(a_values_1, b_values_0)
    # the root node adds 1 to each sum
    return {x + 1 for x in sums}
```


Now our solution is correct, but testing it on some of the larger values of `n` reveals its terrible performance. We can try applying memoization with the `@functools.cache` decorator (and returning `frozenset` instead of `set`, so that the cached results cannot be modified by accident), which helps a bit, but we still do not meet the time constraints of the challenge.

We can see that most of the running time is spent in our `unique_sums` implementation, which computes every pairwise sum. There isn't really a faster way to compute this in general (or at least I couldn't think of one), but the sets we pass to this function have a special property we can exploit.


A look at the possible values for the first few `n` reveals the following, where `pv0` and `pv1` stand for `possible_values_0` and `possible_values_1`:


n  | pv0(n) (root is 0)        | pv1(n) (root is 1)
---|---------------------------|---------------------
0  | 0                         | (empty)
1  | 0                         | 1
2  | 0                         | 2
3  | 0, 2                      | 2
4  | 0, 3                      | 2, 3
5  | 0, 2, 3                   | 2, 3, 4
6  | 0, 2, 4                   | 3, 5
7  | 0, 2, 4                   | 3, 5
8  | 0, 2, 3, 4, 5             | 3, 4, 5, 6
9  | 0, 2, 3, 4, 5, 6          | 3, 4, 5, 6, 7
10 | 0, 2, 4, 5, 6, 7          | 3, 4, 5, 6, 7, 8
11 | 0, 2, 4, 5, 6, 7          | 3, 4, 5, 6, 7, 8
12 | 0, 2, 3, 4, 5, 6, 7, 8    | 3, 4, 5, 6, 7, 8, 9
13 | 0, 2, 3, 4, 5, 6, 7, 8, 9 | 3, 4, 5, 6, 7, 8, 9
14 | 0, 2, 4, 6, 8, 10         | 4, 6, 8, 10
15 | 0, 2, 4, 6, 8, 10         | 4, 6, 8, 10

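
A table like this can be regenerated with a sketch along these lines. It reuses the exact (slow) recursion from above with `functools.cache`; the names `pv0`, `pv1` and `sums` are shorthand introduced here, and `divide` is a condensed copy of the function above:

```python
import functools
import itertools


def divide(n: int) -> tuple[int, int]:
    """Subtree sizes of the root's children (same logic as divide above)."""
    p = 1
    while p * 2 - 1 <= n:
        p *= 2
    remaining = n - (p - 1)
    half = p // 2
    return half - 1 + min(remaining, half), half - 1 + max(0, remaining - half)


def sums(a, b) -> frozenset[int]:
    """Exact set of pairwise sums (the slow unique_sums)."""
    return frozenset(x + y for x, y in itertools.product(a, b))


@functools.cache
def pv0(n: int) -> frozenset[int]:
    if n <= 1:
        return frozenset({0})
    a, b = divide(n)
    return sums(pv0(a), pv0(b)) | sums(pv1(a), pv1(b))


@functools.cache
def pv1(n: int) -> frozenset[int]:
    if n == 0:
        return frozenset()
    if n == 1:
        return frozenset({1})
    a, b = divide(n)
    return frozenset(x + 1 for x in sums(pv0(a), pv1(b)) | sums(pv1(a), pv0(b)))


for n in range(16):
    print(n, sorted(pv0(n)), sorted(pv1(n)), sep=" | ")
```

The first row it prints is `0 | [0] | []`.
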

We can notice that most of these sets are either a run of consecutive integers, such as `pv1(13)`, a run of integers spaced two apart, such as `pv0(15)` or `pv1(6)`, or a run of integers spaced two apart followed by a run of consecutive integers, such as `pv0(11)`. The only exception is `pv0(4) = {0, 3}`.

Knowing that most of the inputs to `unique_sums`, which I will abbreviate as `us` from now on, have one of these forms, we can compute the unique sums much more efficiently.

Let `[a, b]` represent the set of all integers from `a` to `b` inclusive, and let `[a, b, 2]` represent the set `{a, a + 2, ..., b}` of every second integer from `a` to `b` (so `b - a` is even). Additionally, let `S(a, b, c)` represent the union of `[a, b, 2]` and `[b, c]`, and call `b` the divider of the set. All three forms above fit this definition, since `[a, c] = S(a, a, c)` and `[a, b, 2] = S(a, b, b)`.

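
To make the notation concrete, here is a small sketch. `S` matches the helper used in the solution, while `decompose` is a hypothetical helper (not part of the solution) that recovers `(a, b, c)` from a set by trying every possible divider:

```python
from itertools import chain
from typing import Optional


def S(a: int, b: int, c: int) -> frozenset[int]:
    """Union of [a, b, 2] (every second integer from a to b) and [b, c]."""
    return frozenset(chain(range(a, b + 1, 2), range(b, c + 1)))


def decompose(s: frozenset[int]) -> Optional[tuple[int, int, int]]:
    """Return (a, b, c) such that s == S(a, b, c), or None if s has no such shape.

    Hypothetical helper for illustration only: it tries every candidate
    divider b with the same parity as a, so it is meant for small examples.
    """
    if not s:
        return None
    a, c = min(s), max(s)
    for b in range(a, c + 1, 2):
        if S(a, b, c) == s:
            return a, b, c
    return None


print(decompose(frozenset({0, 2, 4, 5, 6, 7})))  # pv0(11) -> (0, 4, 7)
print(decompose(frozenset({3, 5})))              # pv1(6)  -> (3, 5, 5)
print(decompose(frozenset({0, 3})))              # pv0(4)  -> None
```
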

Now, it can be shown that `us(A, B)`, where `A = S(a1, a2, a3)` and `B = S(b1, b2, b3)`, is also a set of the form `C = S(c1, c2, c3)`.

First, define `A1 = [a1, a2, 2]`, `A2 = [a2, a3]`, `B1 = [b1, b2, 2]`, and `B2 = [b2, b3]`, so that `us(A, B)` is the union of `us(A1, B1)`, `us(A1, B2)`, `us(A2, B1)`, and `us(A2, B2)`.

Then, `us(A1, B1)` is equal to `[a1 + b1, a2 + b2, 2]`. If we fix the value chosen from `B1` to be `b1`, we get `[a1 + b1, a2 + b1, 2]`. Switching our choice from `B1` to its next element `b1 + 2` only adds a single new value to the output, giving us `[a1 + b1, a2 + b1 + 2, 2]`. Continuing with this logic up to the maximum value of `B1` gives us `[a1 + b1, a2 + b2, 2]`. Similarly, `us(A2, B2)` is equal to `[a2 + b2, a3 + b3]`. For `us(A1, B2)` and `us(A2, B1)`, we get `[a1 + b2, a2 + b3]` and `[b1 + a2, b2 + a3]`, assuming that `A2` and `B2` each contain at least 2 elements: shifting an interval of at least 2 integers in steps of 2 leaves no gaps.


All of these sets except `us(A1, B1)` are of the form `[a, b]`. Thus, their union is `[min(a1 + b2, b1 + a2), a3 + b3]`. We know the union is contiguous since `a2 + b3` and `b2 + a3` (the maximum values of `us(A1, B2)` and `us(A2, B1)`) are greater than or equal to `a2 + b2` (the minimum value of `us(A2, B2)`). Adding on `us(A1, B1) = [a1 + b1, a2 + b2, 2]`, which covers every second integer from `a1 + b1` up to `min(a1 + b2, b1 + a2)` and otherwise lies inside that interval, gives us `S(a1 + b1, min(a1 + b2, b1 + a2), a3 + b3)`.

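
Here is a concrete instance of this argument, taking `A = pv0(11) = S(0, 4, 7)` and `B = pv0(8) = S(0, 2, 5)` from the table. In this sketch `[a, b]` is written as `S(a, a, b)` and `[a, b, 2]` as `S(a, b, b)`, and `sumset` is a brute-force stand-in for `us` introduced only for this check:

```python
from itertools import chain, product


def S(a: int, b: int, c: int) -> frozenset[int]:
    """Union of [a, b, 2] and [b, c], in the notation of the text."""
    return frozenset(chain(range(a, b + 1, 2), range(b, c + 1)))


def sumset(x: frozenset[int], y: frozenset[int]) -> frozenset[int]:
    """Brute-force us(x, y): every pairwise sum."""
    return frozenset(p + q for p, q in product(x, y))


# A = pv0(11) = S(0, 4, 7) and B = pv0(8) = S(0, 2, 5), both from the table
a1, a2, a3 = 0, 4, 7
b1, b2, b3 = 0, 2, 5
A1, A2 = S(a1, a2, a2), S(a2, a2, a3)  # {0, 2, 4} and {4, 5, 6, 7}
B1, B2 = S(b1, b2, b2), S(b2, b2, b3)  # {0, 2} and {2, 3, 4, 5}

# the four partial sums from the text
assert sumset(A1, B1) == S(a1 + b1, a2 + b2, a2 + b2)  # [0, 6, 2]
assert sumset(A2, B2) == S(a2 + b2, a2 + b2, a3 + b3)  # [6, 12]
assert sumset(A1, B2) == S(a1 + b2, a1 + b2, a2 + b3)  # [2, 9]
assert sumset(A2, B1) == S(b1 + a2, b1 + a2, b2 + a3)  # [4, 9]

# their union has the predicted shape S(a1 + b1, min(a1 + b2, b1 + a2), a3 + b3)
total = sumset(S(a1, a2, a3), S(b1, b2, b3))
assert total == S(a1 + b1, min(a1 + b2, b1 + a2), a3 + b3)
print(sorted(total))  # [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
```
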

Now we need to consider when `A2` or `B2` contains only 1 element (when `a2 = a3` or `b2 = b3`). If `A2` has only 1 element, `us(A2, B1)` is `[b1 + a2, b2 + a3, 2]` rather than an interval, so we use `a1 + b2` as the divider instead. Likewise, if `B2` has only 1 element, `us(A1, B2)` is `[a1 + b2, a2 + b3, 2]`, so we use `b1 + a2` as the divider instead. If both `A2` and `B2` have only 1 element (both sets contain only every second integer), `us(A1, B1)` is the same as `us(A, B)`, so we use `a3 + b3` as the divider.

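
The single-element cases can be checked the same way. This sketch uses sets from the table plus one made-up set, `S(0, 6, 7)`, chosen because it makes the general formula and the special case disagree; `S` follows the notation above and `sumset` is a brute-force stand-in for `us`, both introduced only for this check:

```python
from itertools import chain, product


def S(a: int, b: int, c: int) -> frozenset[int]:
    """Union of [a, b, 2] and [b, c], in the notation of the text."""
    return frozenset(chain(range(a, b + 1, 2), range(b, c + 1)))


def sumset(x: frozenset[int], y: frozenset[int]) -> frozenset[int]:
    """Brute-force us(x, y): every pairwise sum."""
    return frozenset(p + q for p, q in product(x, y))


# A2 has one element: A = pv0(7) = S(0, 4, 4) and B = pv0(8) = S(0, 2, 5)
# us(A2, B1) = {4} + {0, 2} = {4, 6} has step 2 instead of being an interval
assert sumset(S(4, 4, 4), S(0, 2, 2)) == S(4, 6, 6)
# so the divider is a1 + b2 = 0 + 2
assert sumset(S(0, 4, 4), S(0, 2, 5)) == S(0, 0 + 2, 4 + 5)

# the distinction matters when b1 + a2 < a1 + b2, e.g. B = S(0, 6, 7) = {0, 2, 4, 6, 7}:
# min(a1 + b2, b1 + a2) would be 4, but 5 is not a possible sum
assert sumset(S(0, 4, 4), S(0, 6, 7)) == S(0, 0 + 6, 4 + 7)
assert 5 not in sumset(S(0, 4, 4), S(0, 6, 7))

# B2 has one element: the mirror image, so the divider is b1 + a2 = 0 + 2
assert sumset(S(0, 2, 5), S(0, 4, 4)) == S(0, 0 + 2, 5 + 4)

# both have one element: A = pv0(15) = S(0, 10, 10) and B = pv1(14) = S(4, 10, 10)
# the whole sum has step 2, so the divider is a3 + b3 = 20
assert sumset(S(0, 10, 10), S(4, 10, 10)) == S(0 + 4, 10 + 10, 10 + 10)
print("all single-element cases match")
```
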

Writing this in code gives us the following:


```python
import itertools


def find_divider(x: frozenset[int], x_min: int, x_max: int) -> int:
    # step through x_min, x_min + 2, ... until the next integer is also in x
    divider = x_min
    while divider < x_max and divider + 1 not in x:
        divider += 2
    return divider


def S(a: int, b: int, c: int) -> frozenset[int]:
    # every second integer from a to b, then every integer from b to c
    return frozenset(itertools.chain(range(a, b + 1, 2), range(b, c + 1)))


def unique_sums(a: frozenset[int], b: frozenset[int]) -> frozenset[int]:
    if not a or not b:
        return frozenset()
    a1, a3 = min(a), max(a)
    b1, b3 = min(b), max(b)
    a2 = find_divider(a, a1, a3)
    b2 = find_divider(b, b1, b3)
    # A2 and B2 both have a single element
    if a2 == a3 and b2 == b3:
        divider = a3 + b3
    # only A2 has a single element
    elif a2 == a3:
        divider = a1 + b2
    # only B2 has a single element
    elif b2 == b3:
        divider = b1 + a2
    else:
        divider = min(a1 + b2, b1 + a2)
    return S(a1 + b1, divider, a3 + b3)
```

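
Since a mistake in this case analysis would silently produce wrong answers, an exhaustive comparison against brute force on small inputs is reassuring. This sketch copies `find_divider`, `S` and `unique_sums` from above so it runs standalone; `all_S_sets` is a hypothetical generator of every `S(x1, x2, x3)` with small parameters:

```python
import itertools


def find_divider(x: frozenset[int], x_min: int, x_max: int) -> int:
    divider = x_min
    while divider < x_max and divider + 1 not in x:
        divider += 2
    return divider


def S(a: int, b: int, c: int) -> frozenset[int]:
    return frozenset(itertools.chain(range(a, b + 1, 2), range(b, c + 1)))


def unique_sums(a: frozenset[int], b: frozenset[int]) -> frozenset[int]:
    if not a or not b:
        return frozenset()
    a1, a3 = min(a), max(a)
    b1, b3 = min(b), max(b)
    a2 = find_divider(a, a1, a3)
    b2 = find_divider(b, b1, b3)
    if a2 == a3 and b2 == b3:
        divider = a3 + b3
    elif a2 == a3:
        divider = a1 + b2
    elif b2 == b3:
        divider = b1 + a2
    else:
        divider = min(a1 + b2, b1 + a2)
    return S(a1 + b1, divider, a3 + b3)


def all_S_sets(max_start: int = 2, max_steps: int = 2, max_tail: int = 3):
    """Hypothetical generator: every S(x1, x2, x3) with x1 <= max_start,
    x2 - x1 in {0, 2, ..., 2 * max_steps} and x3 - x2 <= max_tail."""
    for x1 in range(max_start + 1):
        for x2 in range(x1, x1 + 2 * max_steps + 1, 2):
            for x3 in range(x2, x2 + max_tail + 1):
                yield S(x1, x2, x3)


shapes = list(all_S_sets())
for A, B in itertools.product(shapes, repeat=2):
    expected = frozenset(x + y for x, y in itertools.product(A, B))
    assert unique_sums(A, B) == expected, (sorted(A), sorted(B))
print(f"closed form matches brute force on {len(shapes) ** 2} pairs")
```

With the default ranges this checks 36 × 36 = 1296 pairs. It only covers inputs that really have the `S(a, b, c)` shape, so sets like `pv0(4)` are outside its scope.
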

As for the one exception, `pv0(4) = {0, 3}`, it turns out that our implementation coincidentally still works (and even if it didn't, we could just fall back to the slower method when the sets are small).

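
To see what happens with `pv0(4) = {0, 3}`, consider `n = 8`, where `divide(8) == (4, 3)`. In the sketch below, `fast_us` is the closed-form `unique_sums` from above and `exact_us` is the brute-force version (both names are introduced here); the fast path produces a spurious sum that the other branch of `pv0(8)` happens to produce anyway:

```python
import itertools


def find_divider(x, x_min, x_max):
    divider = x_min
    while divider < x_max and divider + 1 not in x:
        divider += 2
    return divider


def S(a, b, c):
    return frozenset(itertools.chain(range(a, b + 1, 2), range(b, c + 1)))


def fast_us(a, b):
    """The closed-form unique_sums from the text."""
    if not a or not b:
        return frozenset()
    a1, a3, b1, b3 = min(a), max(a), min(b), max(b)
    a2, b2 = find_divider(a, a1, a3), find_divider(b, b1, b3)
    if a2 == a3 and b2 == b3:
        divider = a3 + b3
    elif a2 == a3:
        divider = a1 + b2
    elif b2 == b3:
        divider = b1 + a2
    else:
        divider = min(a1 + b2, b1 + a2)
    return S(a1 + b1, divider, a3 + b3)


def exact_us(a, b):
    """Brute-force unique_sums."""
    return frozenset(x + y for x, y in itertools.product(a, b))


# values from the table for n = 4 and n = 3, the two subtrees of an 8-node tree
pv0_4, pv1_4 = frozenset({0, 3}), frozenset({2, 3})
pv0_3, pv1_3 = frozenset({0, 2}), frozenset({2})

print(find_divider(pv0_4, 0, 3))       # 2, even though 2 is not in {0, 3}
print(sorted(fast_us(pv0_4, pv0_3)))   # [0, 2, 3, 4, 5]: 4 is spurious
print(sorted(exact_us(pv0_4, pv0_3)))  # [0, 2, 3, 5]
print(sorted(exact_us(pv1_4, pv1_3)))  # [4, 5]: the "both roots are 1" branch adds 4 anyway

# pv0(8) is the union of both branches, and the two methods agree on it
fast = fast_us(pv0_4, pv0_3) | fast_us(pv1_4, pv1_3)
exact = exact_us(pv0_4, pv0_3) | exact_us(pv1_4, pv1_3)
print(sorted(fast), fast == exact)     # [0, 2, 3, 4, 5] True
```

Either way `pv0(8)` comes out as `{0, 2, 3, 4, 5}`, matching the table.
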

Now we can submit the following program:


```python
import functools
import itertools


@functools.cache
def divide(n: int) -> tuple[int, int]:
    p = 1
    # a perfect binary tree has 2 ** k - 1 nodes where k is its height + 1
    while p * 2 - 1 <= n:
        p *= 2
    # now the number of nodes excluding the last unfilled row = p - 1
    remaining_nodes = n - (p - 1)
    # for each subtree, the number of nodes excluding the last unfilled row is p // 2 - 1
    perfect_subtree_nodes = p // 2 - 1
    # the left subtree can have at most p // 2 of the remaining nodes
    left_subtree = perfect_subtree_nodes + min(remaining_nodes, p // 2)
    # the right subtree does not get any of the remaining nodes until the left subtree is full
    right_subtree = perfect_subtree_nodes + max(0, remaining_nodes - p // 2)
    return left_subtree, right_subtree


@functools.cache
def solve(n: int) -> int:
    return len(possible_values_0(n) | possible_values_1(n))


@functools.cache
def possible_values_0(n: int) -> frozenset[int]:
    if n <= 1:
        return frozenset({0})
    a, b = divide(n)
    a_values_0, b_values_0 = possible_values_0(a), possible_values_0(b)
    a_values_1, b_values_1 = possible_values_1(a), possible_values_1(b)
    # the root is 0 when both subtree roots are 0 or both are 1
    return unique_sums(a_values_0, b_values_0) | unique_sums(a_values_1, b_values_1)


@functools.cache
def possible_values_1(n: int) -> frozenset[int]:
    if n == 0:
        return frozenset()
    if n == 1:
        return frozenset({1})
    a, b = divide(n)
    a_values_0, b_values_0 = possible_values_0(a), possible_values_0(b)
    a_values_1, b_values_1 = possible_values_1(a), possible_values_1(b)
    # the root is 1 when exactly one of the subtree roots is 1
    sums = unique_sums(a_values_0, b_values_1) | unique_sums(a_values_1, b_values_0)
    # the root node adds 1 to each sum
    return frozenset(x + 1 for x in sums)


def find_divider(x: frozenset[int], x_min: int, x_max: int) -> int:
    divider = x_min
    while divider < x_max and divider + 1 not in x:
        divider += 2
    return divider


def S(a: int, b: int, c: int) -> frozenset[int]:
    return frozenset(itertools.chain(range(a, b + 1, 2), range(b, c + 1)))


def unique_sums(a: frozenset[int], b: frozenset[int]) -> frozenset[int]:
    if not a or not b:
        return frozenset()
    a1, a3 = min(a), max(a)
    b1, b3 = min(b), max(b)
    a2 = find_divider(a, a1, a3)
    b2 = find_divider(b, b1, b3)
    # A2 and B2 both have a single element
    if a2 == a3 and b2 == b3:
        divider = a3 + b3
    # only A2 has a single element
    elif a2 == a3:
        divider = a1 + b2
    # only B2 has a single element
    elif b2 == b3:
        divider = b1 + a2
    else:
        divider = min(a1 + b2, b1 + a2)
    return S(a1 + b1, divider, a3 + b3)


t = int(input())
for _ in range(t):
    n = int(input())
    print(solve(n))
```


Our program succeeds and gets us the flag: `SEKAI{would_you_be_a_dear_for_me_and-_ok._f09fa68cf09f91a7}`.