title: Project SEKAI CTF 2024: PPC - Nokotan
date: 2024-08-30
tags: ctf, ppc

Task

Time limit is 2 seconds for this challenge.

https://ppc.chals.sekai.team/

nokotan.pdf

  • Author: null_awe
  • Points: 115
  • Solves: 55 / 1230 (4.471%)

Writeup

The goal of this challenge is to write a program that determines the number of possible values (the sum of all node labels) of a complete binary tree with n nodes, where each leaf is labeled either 0 or 1 and every other node is labeled with the XOR of its children.

We are first given the number of test cases t, followed by t lines, each containing the number of nodes in the tree n, where n is in the range [1, 5e5].

Therefore, our script will look something like this:

def solve(n: int) -> int:
  # todo: solve the problem
  pass

t = int(input())
for _ in range(t):
  n = int(input())
  print(solve(n))

The number of leaves in a complete binary tree is roughly half of the number of nodes, and each leaf has two options, so simply iterating over each possible tree will take exponential time.
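For very small n, the exhaustive approach is still useful as a reference to test faster solutions against. The following is a minimal sketch, not part of the submitted solution. It assumes the tree is stored heap-style (node i has children 2i and 2i + 1, nodes numbered from 1) and that a node with a single child simply takes that child's label. The name brute_force_values is hypothetical.

```python
import itertools

def brute_force_values(n: int) -> set[int]:
    """Enumerate every leaf labeling and collect the resulting label sums."""
    # Assumption: heap-style numbering, so node i is a leaf when 2 * i > n.
    leaves = [i for i in range(1, n + 1) if 2 * i > n]
    sums = set()
    for bits in itertools.product((0, 1), repeat=len(leaves)):
        label = [0] * (n + 2)
        for leaf, bit in zip(leaves, bits):
            label[leaf] = bit
        # Children always have larger indices, so fill internal nodes bottom-up.
        for i in range(n, 0, -1):
            if 2 * i <= n:
                right = label[2 * i + 1] if 2 * i + 1 <= n else 0
                label[i] = label[2 * i] ^ right
        sums.add(sum(label[1:n + 1]))
    return sums

print(len(brute_force_values(7)))  # 5 distinct sums: {0, 2, 3, 4, 5}
```

With roughly n / 2 leaves this loop runs about 2 ** (n / 2) times, so it is only practical for n up to a few dozen.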

Instead, we can calculate the possible values for the subtrees starting at the root node's children and determine the number of unique sums of their values.

First, we need to determine how many nodes are in the two subtrees. To do this, we can first ignore the last unfilled row, and then determine how the remaining nodes are distributed.

def divide(n: int) -> tuple[int, int]:
  p = 1
  # a perfect binary tree has 2 ** k - 1 nodes where k is its height+1
  while p * 2 - 1 <= n:
    p *= 2
  # now the number of nodes excluding the last unfilled row = p - 1
  remaining_nodes = n - (p - 1)
  # for each subtree, the number of nodes excluding the last unfilled row is p // 2 - 1
  perfect_subtree_nodes = p // 2 - 1
  # the left subtree can have at most p // 2 of the remaining nodes
  left_subtree = perfect_subtree_nodes + min(remaining_nodes, p // 2)
  # the right subtree does not get any of the remaining nodes until the left subtree is full
  right_subtree = perfect_subtree_nodes + max(0, remaining_nodes - p // 2)
  return left_subtree, right_subtree
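One way to gain confidence in this split is to count the subtree sizes directly. The sketch below assumes heap-style numbering (node i has children 2i and 2i + 1); subtree_size and divide_by_counting are hypothetical helper names, and the counting takes O(n) time, so it is only meant for testing.

```python
def subtree_size(root: int, n: int) -> int:
    """Count the nodes in the subtree rooted at heap index `root` of an n-node tree."""
    if root > n:
        return 0
    return 1 + subtree_size(2 * root, n) + subtree_size(2 * root + 1, n)

def divide_by_counting(n: int) -> tuple[int, int]:
    """Sizes of the subtrees rooted at the root's left (2) and right (3) children."""
    return subtree_size(2, n), subtree_size(3, n)

print(divide_by_counting(6))   # (3, 2)
print(divide_by_counting(12))  # (7, 4)
```

These results should agree with divide(n) for every n, which is easy to check in a loop over small values.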

Now we can implement solve as follows:

import itertools

def solve(n: int) -> int:
  return len(possible_values(n))

def possible_values(n: int) -> set[int]:
  if n == 0:
    return {0}
  if n == 1:
    return {0, 1}
  a, b = divide(n)
  a_values, b_values = possible_values(a), possible_values(b)
  return unique_sums(a_values, b_values)

def unique_sums(a: set[int], b: set[int]) -> set[int]:
  return {x + y for x, y in itertools.product(a, b)}

However, this is ignoring the value of the root node. Since the value of the root node is determined by the values of its children, we need to keep track of possible values where the root node is 1 and where the root node is 0.

def solve(n: int) -> int:
  return len(possible_values_0(n) | possible_values_1(n))

def possible_values_0(n: int) -> set[int]:
  if n <= 1:
    return {0}
  a, b = divide(n)
  a_values_0, b_values_0 = possible_values_0(a), possible_values_0(b)
  a_values_1, b_values_1 = possible_values_1(a), possible_values_1(b)
  # root node is 0 when a = 0 and b = 0 or a = 1 and b = 1
  return unique_sums(a_values_0, b_values_0) | unique_sums(a_values_1, b_values_1)

def possible_values_1(n: int) -> set[int]:
  if n == 0:
    return set()
  if n == 1:
    return {1}
  a, b = divide(n)
  a_values_0, b_values_0 = possible_values_0(a), possible_values_0(b)
  a_values_1, b_values_1 = possible_values_1(a), possible_values_1(b)
  # root node is 1 when a = 0 and b = 1 or a = 1 and b = 0
  sums = unique_sums(a_values_0, b_values_1) | unique_sums(a_values_1, b_values_0)
  # the root node adds 1 to each sum
  return {x + 1 for x in sums}

Now our solution is valid, but testing it on some of the larger values of n reveals its terrible performance. We can try applying memoization with the @functools.cache decorator (forcing us to use frozenset instead of set) which helps a bit, but we still do not meet the time constraints of the challenge.
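A minimal sketch of what the memoized version could look like is shown below. It folds both sets into a single cached function, values, which is a hypothetical name and a restructuring chosen for brevity rather than the exact code that was timed. divide is repeated in compact form so the block runs on its own.

```python
import functools
import itertools

def divide(n: int) -> tuple[int, int]:
    # Same split as before: sizes of the left and right subtrees.
    p = 1
    while p * 2 - 1 <= n:
        p *= 2
    remaining = n - (p - 1)
    half = p // 2
    return half - 1 + min(remaining, half), half - 1 + max(0, remaining - half)

def pair_sums(a: frozenset[int], b: frozenset[int]) -> set[int]:
    return {x + y for x, y in itertools.product(a, b)}

@functools.cache
def values(n: int) -> tuple[frozenset[int], frozenset[int]]:
    """Return (sums with root label 0, sums with root label 1) for n nodes."""
    if n == 0:
        return frozenset({0}), frozenset()
    if n == 1:
        return frozenset({0}), frozenset({1})
    left, right = divide(n)
    l0, l1 = values(left)
    r0, r1 = values(right)
    root0 = pair_sums(l0, r0) | pair_sums(l1, r1)
    # The root itself contributes 1 when its children differ.
    root1 = {s + 1 for s in pair_sums(l0, r1) | pair_sums(l1, r0)}
    return frozenset(root0), frozenset(root1)

pv0, pv1 = values(13)
print(len(pv0 | pv1))
```

Returning frozensets keeps the cached results immutable, so a caller cannot accidentally modify a value that later calls will reuse. The cache removes repeated subproblems, but each pair_sums call is still quadratic in the set sizes.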

We can see that most of the running time is spent in our unique_sums implementation. There isn't really a faster way to compute this value in general (or at least I couldn't think of a way), but the sets we pass to this function have a special property we can exploit.

A look at the possible values of the first few ns reveals the following:

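| n  | pv0                        | pv1                 |
|----|----------------------------|---------------------|
| 0  | 0                          | ∅ (empty)           |
| 1  | 0                          | 1                   |
| 2  | 0                          | 2                   |
| 3  | 0, 2                       | 2                   |
| 4  | 0, 3                       | 2, 3                |
| 5  | 0, 2, 3                    | 2, 3, 4             |
| 6  | 0, 2, 4                    | 3, 5                |
| 7  | 0, 2, 4                    | 3, 5                |
| 8  | 0, 2, 3, 4, 5              | 3, 4, 5, 6          |
| 9  | 0, 2, 3, 4, 5, 6           | 3, 4, 5, 6, 7       |
| 10 | 0, 2, 4, 5, 6, 7           | 3, 4, 5, 6, 7, 8    |
| 11 | 0, 2, 4, 5, 6, 7           | 3, 4, 5, 6, 7, 8    |
| 12 | 0, 2, 3, 4, 5, 6, 7, 8     | 3, 4, 5, 6, 7, 8, 9 |
| 13 | 0, 2, 3, 4, 5, 6, 7, 8, 9  | 3, 4, 5, 6, 7, 8, 9 |
| 14 | 0, 2, 4, 6, 8, 10          | 4, 6, 8, 10         |
| 15 | 0, 2, 4, 6, 8, 10          | 4, 6, 8, 10         |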

We can notice that most of these sets take one of three forms: a sequence of consecutive integers, such as pv1(13); a sequence of integers spaced 2 apart, such as pv0(15) (all even) or pv1(6) (all odd); or a sequence of integers spaced 2 apart followed by a sequence of consecutive integers, such as pv0(11). The only exception is pv0(4).

Knowing that most of the inputs to unique_sums, which I will abbreviate as us from now on, have one of the forms described above, we can compute the unique sums much more efficiently.

Let [a, b] represent the set containing all integers from a to b inclusive, and let [a, b, 2] represent the set containing every second integer from a to b inclusive, that is a, a + 2, ..., b. I will loosely call such a set "even only", even when its elements are all odd, as in pv1(6) = {3, 5}. Additionally, let S(a, b, c) represent the union of [a, b, 2] and [b, c], and call b the divider of the set.
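To make the notation concrete, here is a small sketch that tries to write a given non-empty set as S(a, b, c). The helper name as_s_form is hypothetical and not part of the solution; it simply tries every candidate divider.

```python
from typing import Optional

def as_s_form(values: set[int]) -> Optional[tuple[int, int, int]]:
    """Return (a, b, c) such that values == S(a, b, c), or None if impossible.

    Assumes `values` is non-empty.
    """
    lo, hi = min(values), max(values)
    # The divider must be reachable from lo in steps of 2.
    for b in range(lo, hi + 1, 2):
        candidate = set(range(lo, b + 1, 2)) | set(range(b, hi + 1))
        if candidate == set(values):
            return lo, b, hi
    return None

print(as_s_form({0, 2, 4, 5, 6, 7}))  # pv0(11) -> (0, 4, 7)
print(as_s_form({3, 5}))              # pv1(6)  -> (3, 5, 5)
print(as_s_form({0, 3}))              # pv0(4)  -> None
```

A run of consecutive integers is the case b = a, and a pure step-2 run is the case b = c, so all three shapes seen in the table are covered by the single form S(a, b, c).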

Now, it can be shown that us(A, B), where A = S(a1, a2, a3) and B = S(b1, b2, b3), will also give us a set of the form C = S(a, b, c).

First, define A1 = [a1, a2, 2], A2 = [a2, a3], B1 = [b1, b2, 2], and B2 = [b2, b3].

Then, us(A1, B1) will be equal to [a1 + b1, a2 + b2, 2]. To see this, fix the value chosen from B1 to be b1, which gives [a1 + b1, a2 + b1, 2]. Switching our choice from B1 to the next larger value, b1 + 2, adds only a single new value to the output, giving [a1 + b1, a2 + b1 + 2, 2]. Continuing with this logic up to the maximum value of B1 gives [a1 + b1, a2 + b2, 2].

Similarly, us(A2, B2) will be equal to [a2 + b2, a3 + b3]. For us(A1, B2) and us(A2, B1), we get [a1 + b2, a2 + b3] and [b1 + a2, b2 + a3] respectively, assuming that B2 (for the former) and A2 (for the latter) contain 2 or more elements.

All of these sets except us(A1, B1) are of the form [a, b]. Thus, their union will be [min(a1 + b2, b1 + a2), a3 + b3]. We know the union is contiguous since a2 + b3 and b2 + a3 (the maximum values of us(A1, B2) and us(A2, B1)) are greater than or equal to a2 + b2 (the minimum value of us(A2, B2)). Because a2 - a1 and b2 - b1 are both even, this lower bound has the same parity as a1 + b1, so the step-2 values of us(A1, B1) below it line up with it. Adding on us(A1, B1) therefore gives us S(a1 + b1, min(a1 + b2, b1 + a2), a3 + b3).

Now we need to consider when A2 or B2 only contain 1 element (when a2 = a3 or b2 = b3). If A2 only has 1 element, us(A2, B1) will be [b1 + a2, b2 + a3, 2], so we use a1 + b2 as the divider instead. Likewise, if B2 only has 1 element, us(A1, B2) will be [a1 + b2, a2 + b3, 2], so we use b1 + a2 as the divider instead. If both A2 and B2 only have 1 element (both are even only), us(A1, B1) is the same as us(A, B), so we use a3 + b3 as the divider.

Writing this in code gives us the following:

def find_divider(x: frozenset[int], x_min: int, x_max: int) -> int:
  divider = x_min
  while divider < x_max and divider + 1 not in x:
    divider += 2
  return divider

def S(a: int, b: int, c: int) -> frozenset[int]:
  return frozenset(x for x in itertools.chain(range(a, b + 1, 2), range(b, c + 1)))

def unique_sums(a: frozenset[int], b: frozenset[int]) -> frozenset[int]:
  if not a or not b:
    return frozenset()
  a1, a3 = min(a), max(a)
  b1, b3 = min(b), max(b)
  a2 = find_divider(a, a1, a3)
  b2 = find_divider(b, b1, b3)
  # a and b are even only
  if a2 == a3 and b2 == b3:
    divider = a3 + b3
  # a is even only
  elif a2 == a3:
    divider = a1 + b2
  # b is even only
  elif b2 == b3:
    divider = b1 + a2
  else:
    divider = min(a1 + b2, b1 + a2)
  return S(a1 + b1, divider, a3 + b3)
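The case analysis above can be sanity-checked by comparing the closed form with the brute-force product on every small pair of S-form sets. This is a sketch for testing only: it works on (a, b, c) triples directly instead of frozensets, and closed_form, brute and triples are hypothetical names.

```python
import itertools

def S(a: int, b: int, c: int) -> set[int]:
    return set(range(a, b + 1, 2)) | set(range(b, c + 1))

def closed_form(A: tuple[int, int, int], B: tuple[int, int, int]) -> set[int]:
    (a1, a2, a3), (b1, b2, b3) = A, B
    if a2 == a3 and b2 == b3:      # both sets are step-2 only
        divider = a3 + b3
    elif a2 == a3:                 # only A is step-2 only
        divider = a1 + b2
    elif b2 == b3:                 # only B is step-2 only
        divider = b1 + a2
    else:
        divider = min(a1 + b2, b1 + a2)
    return S(a1 + b1, divider, a3 + b3)

def brute(A: tuple[int, int, int], B: tuple[int, int, int]) -> set[int]:
    return {x + y for x, y in itertools.product(S(*A), S(*B))}

def triples(limit: int):
    """All (a, b, c) with b - a even and b <= c, within a small range."""
    for a in range(limit):
        for b in range(a, a + 2 * limit, 2):
            for c in range(b, b + limit):
                yield a, b, c

mismatches = [
    (A, B)
    for A in triples(4)
    for B in triples(4)
    if closed_form(A, B) != brute(A, B)
]
print(len(mismatches))  # 0
```

An empty mismatch list over a small range is evidence rather than proof, but it exercises all four branches, including the single-element cases.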

As for the one exception, pv0(4) = {0, 3}, it coincidentally turns out that our implementation still works (and even if it didn't, we could just use the slower method when the set size is small).
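To illustrate the coincidence for one case: find_divider reads {0, 3} as minimum 0, divider 2, maximum 3, so unique_sums effectively treats pv0(4) as S(0, 2, 3) = {0, 2, 3}. The sketch below takes the sets from the table and shows that, for n = 8, the extra sums this produces are already covered by the other half of each union. It checks only n = 8 and is not a proof for larger n; exact_sums and the variable names are hypothetical.

```python
import itertools

def exact_sums(a: set[int], b: set[int]) -> set[int]:
    return {x + y for x, y in itertools.product(a, b)}

# Values taken from the table above (n = 8 splits into subtrees of 4 and 3 nodes).
pv0_3, pv1_3 = {0, 2}, {2}
pv0_4, pv1_4 = {0, 3}, {2, 3}

# How the optimized unique_sums effectively sees pv0(4).
pv0_4_as_seen = {0, 2, 3}

true_pv0_8 = exact_sums(pv0_4, pv0_3) | exact_sums(pv1_4, pv1_3)
fast_pv0_8 = exact_sums(pv0_4_as_seen, pv0_3) | exact_sums(pv1_4, pv1_3)

true_pv1_8 = {s + 1 for s in exact_sums(pv0_4, pv1_3) | exact_sums(pv1_4, pv0_3)}
fast_pv1_8 = {s + 1 for s in exact_sums(pv0_4_as_seen, pv1_3) | exact_sums(pv1_4, pv0_3)}

print(true_pv0_8 == fast_pv0_8, true_pv1_8 == fast_pv1_8)  # True True
```

The spurious value 4 in the first term of pv0(8) also arises from 2 + 2 in the second term, which is why the final sets agree.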

Now we can submit the following program:

import itertools
import functools

@functools.cache
def divide(n: int) -> tuple[int, int]:
  p = 1
  # a perfect binary tree has 2 ** k - 1 nodes where k is its height+1
  while p * 2 - 1 <= n:
    p *= 2
  # now the number of nodes excluding the last unfilled row = p - 1
  remaining_nodes = n - (p - 1)
  # for each subtree, the number of nodes excluding the last unfilled row is p // 2 - 1
  perfect_subtree_nodes = p // 2 - 1
  # the left subtree can have at most p // 2 of the remaining nodes
  left_subtree = perfect_subtree_nodes + min(remaining_nodes, p // 2)
  # the right subtree does not get any of the remaining nodes until the left subtree is full
  right_subtree = perfect_subtree_nodes + max(0, remaining_nodes - p // 2)
  return left_subtree, right_subtree

@functools.cache
def solve(n: int) -> int:
  return len(possible_values_0(n) | possible_values_1(n))

@functools.cache
def possible_values_0(n: int) -> frozenset[int]:
  if n <= 1:
    return frozenset({0})
  a, b = divide(n)
  a_values_0, b_values_0 = possible_values_0(a), possible_values_0(b)
  a_values_1, b_values_1 = possible_values_1(a), possible_values_1(b)
  # root node is 0 when a = 0 and b = 0 or a = 1 and b = 1
  return unique_sums(a_values_0, b_values_0) | unique_sums(a_values_1, b_values_1)

@functools.cache
def possible_values_1(n: int) -> frozenset[int]:
  if n == 0:
    return frozenset()
  if n == 1:
    return frozenset({1})
  a, b = divide(n)
  a_values_0, b_values_0 = possible_values_0(a), possible_values_0(b)
  a_values_1, b_values_1 = possible_values_1(a), possible_values_1(b)
  # root node is 1 when a = 0 and b = 1 or a = 1 and b = 0
  sums = unique_sums(a_values_0, b_values_1) | unique_sums(a_values_1, b_values_0)
  # the root node adds 1 to each sum
  return frozenset(x + 1 for x in sums)

def find_divider(x: frozenset[int], x_min: int, x_max: int) -> int:
  divider = x_min
  while divider < x_max and divider + 1 not in x:
    divider += 2
  return divider

def S(a: int, b: int, c: int) -> frozenset[int]:
  return frozenset(x for x in itertools.chain(range(a, b + 1, 2), range(b, c + 1)))

def unique_sums(a: frozenset[int], b: frozenset[int]) -> frozenset[int]:
  if not a or not b:
    return frozenset()
  a1, a3 = min(a), max(a)
  b1, b3 = min(b), max(b)
  a2 = find_divider(a, a1, a3)
  b2 = find_divider(b, b1, b3)
  # a and b are even only
  if a2 == a3 and b2 == b3:
    divider = a3 + b3
  # a is even only
  elif a2 == a3:
    divider = a1 + b2
  # b is even only
  elif b2 == b3:
    divider = b1 + a2
  else:
    divider = min(a1 + b2, b1 + a2)
  return S(a1 + b1, divider, a3 + b3)

t = int(input())
for _ in range(t):
  n = int(input())
  print(solve(n))

Our program succeeds and gets us the flag: SEKAI{would_you_be_a_dear_for_me_and-_ok._f09fa68cf09f91a7}.