Editorial for CCC '18 S4 - Balanced Trees


Remember to use this editorial only when stuck, and not to copy-paste code from it. Please be respectful to the problem author and editorialist.
Submitting an official solution before solving the problem yourself is a bannable offence.

Author: Eliden

The first phase of solving this problem involves converting the problem statement into an equation. Let f(n) be the number of perfectly balanced trees with weight n. The base case is f(1)=1. For n \ge 2, the number of subtrees is between 2 and n. If there are k subtrees, then the weight of each subtree is \left\lfloor \frac{n}{k}\right\rfloor. Since the subtrees must be completely identical, the number of possibilities is f\left(\left\lfloor \frac{n}{k}\right\rfloor\right). Thus, we have the recurrence \displaystyle f(n) = \sum_{k=2}^nf\left(\left\lfloor n/k\right\rfloor\right)

A direct DP implementation of this will take \mathcal O(N^2) time, and pass the first subtask.
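As a rough illustration of the direct DP, here is a minimal Python sketch. The function name `count_trees_dp` is made up for this example and does not come from an official solution; it only fills a table using the recurrence above.

```python
def count_trees_dp(n: int) -> int:
    """Quadratic-time table fill for f(m) = sum_{k=2}^{m} f(m // k)."""
    f = [0] * (n + 1)
    f[1] = 1  # base case: a single leaf
    for m in range(2, n + 1):
        # Try every possible number of subtrees k for a root of weight m.
        f[m] = sum(f[m // k] for k in range(2, m + 1))
    return f[n]


if __name__ == "__main__":
    # By hand: f(10) = f(5) + f(3) + f(2) + f(2) + 5 * f(1) = 4 + 2 + 1 + 1 + 5
    print(count_trees_dp(10))
```

The double loop makes the quadratic cost visible: computing f(m) touches m - 1 earlier entries.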

One key fact for improving this time complexity is that there are \mathcal O\left(\sqrt{N}\right) distinct integers of the form \left\lfloor N/j\right\rfloor for integer j between 2 and N. This is because there are fewer than \sqrt{N} values of j below \sqrt{N}, while for j \ge \sqrt{N} we have \left\lfloor N/j\right\rfloor \le \sqrt{N}, which leaves at most \sqrt{N} possible values. This idea is related to the following two observations.
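The counting claim can be checked by brute force for small N. The sketch below is only a sanity check, and the helper name `distinct_quotients` is an assumption of this example.

```python
import math


def distinct_quotients(n: int) -> list[int]:
    """All distinct values of n // j for 2 <= j <= n, in increasing order."""
    return sorted({n // j for j in range(2, n + 1)})


if __name__ == "__main__":
    n = 15
    values = distinct_quotients(n)
    # The number of distinct values stays within 2 * floor(sqrt(n)).
    print(values, len(values), 2 * math.isqrt(n))
```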

First, the number of intermediate values f(i) required to compute f(N) is significantly smaller than N. Because \left\lfloor{\left\lfloor{i/a}\right\rfloor/b}\right\rfloor=\left\lfloor{i/(ab)}\right\rfloor, we will only need to know values of f(i) where i=\left\lfloor{N/j}\right\rfloor and j is an integer between 2 and N, to find f(N) using the recursive formula.

Suppose we use this idea to only compute f(i) for the required values. Assuming we already know f(j) for all necessary j<i, computing f(i) takes \mathcal O(i) time. Thus, the time complexity for computing f(N) using this method is bounded asymptotically by \displaystyle \sum_{j=1}^N \mathcal O\left(\frac{N}{j}\right)=\mathcal O(N\log N)

This is fast enough to pass the first 3 subtasks. This algorithm is still easy to implement; one only needs a recursive function with memoization. Because of the very structured form of the necessary states, the memoization can be done simply with two arrays of size \sqrt{N} (i.e. there is no need for map data structures). To elaborate more on how this form of memoization works, note that there are really two types of numbers of the form \left\lfloor{N/j}\right\rfloor, ones greater than \sqrt{N} and ones less than or equal to it. For the latter type, indexing can be done as usual, on memory of size \sqrt{N}. For the former type, the indices are of the form \left\lfloor N/j\right\rfloor=i with i>\sqrt{N}. In this case, we can represent the i more densely using the index k=\left\lfloor N/i\right\rfloor, where for i>\sqrt{N}, it can be shown that the indexing is unique.
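The two-array indexing described above might look roughly like this in Python. The names `small` and `large` are assumptions of this sketch, and 0 marks "not yet computed", which is safe because every f value is at least 1. Each state is still summed term by term here, so only the set of states is reduced.

```python
import math


def count_trees_memo(n: int) -> int:
    root = math.isqrt(n)
    small = [0] * (root + 2)  # small[i] caches f(i) for i <= root
    large = [0] * (root + 2)  # large[n // i] caches f(i) for i > root

    def f(i: int) -> int:
        if i == 1:
            return 1
        # States above sqrt(n) are stored densely under the index n // i.
        table, slot = (small, i) if i <= root else (large, n // i)
        if table[slot] == 0:
            table[slot] = sum(f(i // k) for k in range(2, i + 1))
        return table[slot]

    return f(n)


if __name__ == "__main__":
    print(count_trees_memo(15))
```

The recursion depth stays logarithmic in n, since each call at least halves its argument.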

Although this form of memoization is probably the most sophisticated, it is hardly necessary to solve the problem fully. For all but perhaps the last subtask, it suffices to use an array of N elements. To optimize this memory consumption, one can set a size limit to some large constant (as high as possible without running out of memory), memoize the states which fit into this array, and just not memoize the other states (or alternatively, use a slower method of memoization for these few particular states). Finally, for languages with fast hash maps such as C++, using a hash map is an option that does pass the time limit, despite being noticeably slower than other approaches.

If memoized recursion takes time \mathcal O(N\log N), how fast is unmemoized recursion? Intuitively, this time is roughly proportional to the value of f(N) itself. We can show that f(n)=\mathcal O(n^c) for a particular value of c by induction. The key inductive step would be to show \displaystyle n^c \ge \sum_{k=2}^n\left(\left\lfloor{n/k}\right\rfloor\right)^c

This inequality is satisfied if \sum_{k=2}^\infty\frac{1}{k^c}=1, or equivalently, if \zeta(c)=2 (c\approx 1.729). This means that f(N) has an upper bound of \mathcal O(N^{1.729}). Empirically, the exponent appears to be slightly lower (at least for small N), due to the error introduced by eliminating the floors and making the sum infinite in the argument above. Thus, perhaps surprisingly, naive recursion will pass the first two subtasks. Also, this indicates that the value of f(N) will fit in a 64-bit integer for all values of N given in the problem. Fun fact: if someone solving the problem were to assume that the output always fits in a 64-bit integer, then they could theoretically reverse-engineer that fact to discover that the complexity of naive recursion isn't that bad after all.
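The constant c can be reproduced numerically. The sketch below is one possible way to do it, not part of the original analysis: it approximates \zeta(s) by a partial sum plus an integral tail estimate, then bisects on \zeta(c)=2. The function names, the number of terms, and the bracket [1.5, 2] are all choices made for this example.

```python
def zeta(s: float, terms: int = 1000) -> float:
    """Approximate the Riemann zeta function for s > 1."""
    head = sum(k ** -s for k in range(1, terms))
    # Tail from `terms` onward: integral estimate plus half of the first term.
    tail = terms ** (1 - s) / (s - 1) + 0.5 * terms ** -s
    return head + tail


def solve_zeta_equals_two(lo: float = 1.5, hi: float = 2.0, iterations: int = 60) -> float:
    """Bisection for zeta(c) = 2; zeta is decreasing for s > 1."""
    for _ in range(iterations):
        mid = (lo + hi) / 2
        if zeta(mid) > 2.0:
            lo = mid  # zeta still too large, so c lies to the right
        else:
            hi = mid
    return (lo + hi) / 2


if __name__ == "__main__":
    print(round(solve_zeta_equals_two(), 3))
```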

The second main observation for solving the problem fully is that the recursive formula for f(n) can be optimized to make \mathcal O(\sqrt{n}) recursive calls. Furthermore, f(n) can be computed in \mathcal O(\sqrt{n}) time if we put aside the cost to compute f for smaller numbers. This follows from the fact that there are \mathcal O(\sqrt{n}) distinct integers of the form \left\lfloor{n/k}\right\rfloor. The only catch is that to compute \sum_{k=2}^nf\left(\left\lfloor n/k\right\rfloor\right) efficiently, one must know the number of values of k for which a particular integer j equals \left\lfloor n/k\right\rfloor. After some thinking, one can discover that this quantity has the following closed form: \displaystyle \left\lfloor n/j\right\rfloor-\left\lfloor n/(j+1)\right\rfloor

Thus, for n \ge 2, we have \displaystyle f(n) = \sum_{1 \le j \le \sqrt{n}}\left(\left\lfloor n/j\right\rfloor-\left\lfloor n/(j+1)\right\rfloor\right)f(j)+\sum_{\substack{j \ge 2\\\left\lfloor n/j\right\rfloor>\sqrt{n}}}f\left(\left\lfloor n/j\right\rfloor\right)
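To make the two-part sum concrete, here is a small Python sketch that evaluates it for one n \ge 2, given a table of already known smaller values. The function name `f_split` and the list-based table are assumptions of this example.

```python
import math


def f_split(n: int, f: list[int]) -> int:
    """Evaluate the split sum for n >= 2, where f[m] is known for all m < n."""
    root = math.isqrt(n)
    total = 0
    # Small quotients j <= sqrt(n): weight f(j) by how many k give n // k == j.
    for j in range(1, root + 1):
        total += (n // j - n // (j + 1)) * f[j]
    # Large quotients: each k >= 2 with n // k > sqrt(n) contributes once.
    k = 2
    while k <= n and n // k > root:
        total += f[n // k]
        k += 1
    return total


if __name__ == "__main__":
    table = [0, 1]  # table[1] = f(1) = 1; index 0 is unused
    for m in range(2, 16):
        table.append(f_split(m, table))
    print(table[15])
```

Both loops run at most about \sqrt{n} times, because n // k > \sqrt{n} forces k < \sqrt{n}.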

If we use this observation on its own (that is, without any of the previous optimizations on the number of states), then the time complexity is \displaystyle \sum_{j=1}^N \mathcal O\left(\sqrt{j}\right)=\mathcal O\left(N^{3/2}\right)

This solution passes the first two subtasks.

Combined with the previous observation about the set of required states, however, the time complexity is \displaystyle \sum_{1 \le j \le \sqrt{N}}\mathcal O\left(\sqrt{j}\right)+\sum_{1 \le j \le \sqrt{N}}\mathcal O\left(\sqrt{N/j}\right)=\mathcal O\left(\left(\sqrt{N}\right)^{3/2}\right)+\mathcal O\left(\sqrt{N}\left(\sqrt{N}\right)^{1/2}\right)=\mathcal O\left(N^{3/4}\right)

Thus, we can compute f(N) in \mathcal O(N^{3/4}) time, which is good enough to pass the final subtask.
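Putting both observations together, a full solution could be sketched as follows in Python. This is one possible arrangement rather than the reference implementation: it walks over blocks of k that share the same quotient (an equivalent way of grouping the sum) and stores states in two arrays of size about \sqrt{N}. All identifiers are made up for this sketch.

```python
import math


def count_trees(n: int) -> int:
    root = math.isqrt(n)
    small = [0] * (root + 2)  # f(m) for m <= root, indexed by m
    large = [0] * (root + 2)  # f(m) for m > root, indexed by n // m

    def f(m: int) -> int:
        if m == 1:
            return 1
        table, slot = (small, m) if m <= root else (large, n // m)
        if table[slot]:
            return table[slot]  # f values are >= 1, so 0 means "unset"
        total = 0
        k = 2
        while k <= m:
            q = m // k
            last = m // q  # largest k' with m // k' == q
            total += (last - k + 1) * f(q)
            k = last + 1
        table[slot] = total
        return total

    return f(n)


if __name__ == "__main__":
    print(count_trees(int(input())))
```

The `while` loop visits each distinct quotient once, so every state costs \mathcal O(\sqrt{m}) work on top of its recursive calls.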

In summary, the subtasks admit the following solutions:

  1. Naive DP, \mathcal O(N^2)
  2. Naive recursion, \mathcal O(N^{1.729}); Observation 2, \mathcal O(N^{1.5})
  3. Observation 1 or memoized recursion, \mathcal O(N\log N)
  4. Observations 1 and 2 combined, \mathcal O(N^{3/4})

Note that analysis of these time complexities is an optional part of solving the problem. The simple input format makes it easy for contestants to test the performance of their programs to estimate if they are on the right track.


Comments


  • wuganggame commented on June 15, 2021, 2:02 a.m. (edited)

    A slightly easier-to-understand version:

    A root node with weight n can have 2, 3, 4, \dots, n subtrees:

    • f(1) = 1
    • f(n) = f(n/2) + f(n/3) + f(n/4) + \dots + f(n/n)

    That is, with n/k denoting integer division (rounding down): \displaystyle f(n) = \sum_{k=2}^n f(n/k)

    The recursion above repeats computations. For example:

    \displaystyle \begin{align*}
k &= \quad 2 \qquad 3 \qquad 4 \qquad 5 \qquad 6 \qquad 7 \qquad 8 \qquad \dots \qquad 15 \\
f(15) &= f(7) + f(5) + f(3) + f(3) + f(2) + f(2) + f(1) + \dots + f(1) \\
&= f(7) + f(5) + 2*f(3) + 2*f(2) + 8*f(1)
\end{align*}

    We only need to compute \displaystyle f(7), f(5), f(3), f(2), f(1)

    To count how many times each term appears, start from the smallest k that gives a particular quotient d = n / k:

    1. Calculate the remainder: r = n \mathbin{\%} k (e.g. 15 \mathbin{\%} 4 = 3). This is the slack left over after taking k copies of d.
    2. Divide r by d using integer division: extra = r / d. This is the number of further values of k that give the same quotient d.
    3. Since k itself also gives d, count = extra + 1.

    We can also record previously calculated values in a hash table, so that we don't calculate them again. For example, when we compute f(7), the recursion will eventually calculate f(2). We can store f(2) during the calculation of f(7). Later, when f(2) is needed in the loop for f(15) (k=6,7), we can use the stored value directly.