Note: in the hope of speeding up the process of writing walkthroughs each day, this year I am not going to give a brief summary of the "part 1" problem statement at the beginning of each day. Instead, I will jump right at the solution. The official problem statements are linked throughout the document for reference.
- Day 1 - Trebuchet?!
- Day 2 - Cube Conundrum
- Day 3 - Gear Ratios
- Day 4 - Scratchcards
- Day 5 - If You Give A Seed A Fertilizer
- Day 6 - Wait For It
- Day 7 - Camel Cards
- Day 8 - Haunted Wasteland
- Day 9 - Mirage Maintenance
- Day 10 - TODO
- Day 11 - Cosmic Expansion
- Day 12 - Hot Springs
- Day 13 - Point of Incidence
- Day 14 - Parabolic Reflector Dish
- Day 15 - Lens Library
- Day 16 - The Floor Will Be Lava
- Day 17 - Clumsy Crucible
- Day 18 - Lavaduct Lagoon
- Day 19 - Aplenty
- Day 20 - Pulse Propagation
Problem statement — Complete solution — Back to top
Task seems easy enough. How do you find out if a character is a digit? Simply
check char.isdigit()
. We can do this for each character of
each line of input, first iterating forward to find the first, and then
iterating backwards (using [::-1]
) to find the last. The digits we find will
need to be converted to int
, and the first one will need to also be multiplied
by 10
.
fin = open(...)
total = 0
for line in fin:
for char in line:
if char.isdigit():
total += 10 * int(char)
break
for char in line[::-1]:
if char.isdigit():
total += int(char)
break
We can simplify this with the help of the filter()
built-in function: just filter out any character that satisfies str.isdigit()
.
To only extract the first such character from the iterator returned by
filter()
we can simply use next()
.
for line in fin:
digit1 = next(filter(str.isdigit, line))
digit2 = next(filter(str.isdigit, line[::-1]))
total += 10 * int(digit1) + int(digit2)
print('Part 1:', total)
Things get more complex and this is probably the "hardest" day 1 problem I have seen so far. We need to also consider English words when checking each line of input. The first and last digits to appear either as a digit or as an english word need to be found.
There isn't much to do except checking each spelled out English digit for each
line. We can simplify things by building a dict
to use as a lookup table:
DIGITS = {
'one' : 1,
'two' : 2,
'three': 3,
'four' : 4,
'five' : 5,
'six' : 6,
'seven': 7,
'eight': 8,
'nine' : 9,
}
Now the check is a bit more annoying, so let's create a function for it: it will
take a string and will check whether the first character is a digit (and in that
case return it) or whether the string starts with a spelled-out English digit
(and in that case convert and return it). We'll return 0
in case of no match
for simplicity.
def check_digit(string):
if string[0].isdigit():
return int(string[0])
for d in DIGITS:
if string.startswith(d):
return DIGITS[d]
return 0
The second loop above can again be simplified with the use of filter()
+
next()
, but since this time we are not guaranteed to find an English word in
string
, we need to pass a second argument to next()
for the default value to
return in case filter()
does not match anything.
def check_digit(char, string):
if string[0].isdigit():
return int(string[0])
d = next(filter(string.startswith, DIGITS), None)
return DIGITS.get(d, 0)
We can now integrate the above function in the loop we wrote for part 1, using a
second variable for the total. This time, we'll have to iterate over each
possible substring manually, first forward and then backwards. We can easily use
range()
for that.
total1 = total2 = 0
for line in fin:
# Part 1
total1 += 10 * int(next(filter(str.isdigit, line)))
total1 += int(next(filter(str.isdigit, line[::-1])))
# Part 2
for i in range(len(line)):
digit1 = check_digit(char, line[i:])
if digit1:
break
for i in range(len(line) - 1, -1, -1):
digit2 = check_digit(line[i], line[i:])
if digit2:
break
total2 += 10 * digit1 + digit2
print('Part 1:', total1)
print('Part 2:', total2)
There is technically a way to "simplify" this even more, again with the use of
filter()
+ next()
, but it does not really help anything. However, here it
is, just for the fun of it:
for line in fin:
total2 += 10 * next(filter(None, map(check_digit, (line[i:] for i in range(len(line))))))
total2 += next(filter(None, map(check_digit, (line[i:] for i in range(len(line) -1, -1, -1)))))
First two stars of the year done. Welcome to Advent of Code 2023!
Problem statement — Complete solution — Back to top
We are dealing with... a game using colored cubes. Today's input seems a bit
annoying to parse. While we could do this with a couple of regular expressions,
I usually prefer to avoid those. Thankfully we can get around with only a few
.split()
operations.
The lines are in the form:
Game 3: 1 red, 2 green; 3 green, 1 blue, 7 red; 3 green, 5 red, 1 blue
The first thing to do is extract the game ID and convert it to integer: we can
do this by simply splitting on the colon (:
) - or more precisely on colon
followed by a space (': '
) - and then extract the ID by slicing.
fin = open(...)
for line in fin:
gid, game = line.split(': ')
gid = int(gid[5:])
We can then split on '; '
to separate each "turn" of each game, split again on
', '
to separate each listed number and color combination in the turn, and
again one last time on whitespace to separate the number from the color,
converting numbers to integer.
for line in fin:
gid, game = line.split(': ')
gid = int(gid[5:])
for turn in game.split('; '):
for n_and_color in turn.split(', '):
n, color = n_and_color.split()
n = int(n)
We can simplify the above and split n
and color
on the fly using
map()
and str.split()
:
for turn in game.split('; '):
+ for n, color in map(str.split, turn.split(', ')):
- for n_and_color in turn.split(', '):
- n, color = n_and_color.split()
n = int(n)
Now, since we are tasked with identifying impossible games, for each color we
encounter in a turn, we need to check whether the given number exceeds the
limits we are given: 12 red, 13 green and 14 blue. We can do this pretty easily
with a few if
statements and a boolean variable.
answer = 0
for line in fin:
gid, game = line.split(': ')
gid = int(gid[5:])
bad = False
for turn in game.split('; '):
for n, color in map(str.split, turn.split(', ')):
n = int(n)
if color == 'red' and n > 12:
bad = True
elif color == 'green' and n > 13:
bad = True
elif color == 'blue' and n > 14:
bad = True
if bad:
answer += gid
Those if
statements can be simplified down to a single one by combining the
conditions. Additionally, we can break
out of the loop as soon as we find an
impossible turn:
# ... same as above ...
if n > 14 or (n > 13 and color == 'green') or (n > 12 and color == 'red'):
bad = True
break
if bad:
answer += gid
print('Part 1:', answer)
For part two, for each game we now need to find the minimum number of cubes of each color that would make the game possible. If we think about it for a moment, this simply means computing the maximum value we see for each color over all the turns of a game.
We can adapt the code we just wrote for part 1 and integrate the calculations
for part 2 too. We'll add 3 more variables to keep track of the maximum number
for each color among all the turns using max()
.
answer1 = answer2 = 0
for line in fin:
gid, game = line.split(': ')
gid = int(gid[5:])
# Maximum number of red, green and blue cubes seen in any turn of this game
maxr = maxg = maxb = 0
for turn in game.split('; '):
for n, color in map(str.split, turn.split(', ')):
n = int(n)
if color == 'red':
maxr = max(n, maxr)
elif color == 'green':
maxg = max(n, maxg)
else:
maxb = max(n, maxb)
For each game, at the end of the for turn
loop, we now have maxr
, maxg
,
and maxb
containing respectively the maximum number of red, green and blue
cubes that we saw in any turn. We can use these values to simplify the logic for
the impossibility check of part 1, which can now happen outside the loop, as we
cannot break
earlier as we did before. Since the check is moved outside the
loop, we also don't need the bad
boolean variable anymore.
for line in fin:
# ...
for turn in game.split('; '):
# ...
if maxr <= 12 and maxg <= 13 and maxb <= 14:
answer1 += gid
Or, if we want to avoid branching, we can use a simple multiplication, since a
bool
used in arithmetic expressions evaluates either to 0
or 1
:
- if maxr <= 12 and maxg <= 13 and maxb <= 14:
- answer1 += gid
+ answer1 += gid * (maxr <= 12 and maxg <= 13 and maxb <= 14)
Finally, the value we want for part 2 simply consists of the sum of the products of the 3 maximums over all games:
for line in fin:
# ...
for turn in game.split('; '):
# ...
answer1 += gid * (maxr <= 12 and maxg <= 13 and maxb <= 14)
answer2 += maxr * maxg * maxb
print('Part 1:', answer1)
print('Part 2:', answer2)
And here we go, 4 stars and counting!
Problem statement — Complete solution — Back to top
First ASCII grid problem of the year, let's get right into it! First of all,
parsing: doesn't get much easier than this, we can just read the entire input in
one go and the use .splitlines()
to have a nice list of
strings that we can index like a grid. For later use, let's also calculate width
and height of the grid right away.
fin = open(...)
grid = fin.read().splitlines()
witdh, height = len(grid), len(grid[0])
Now onto the real problem: given that the numbers scattered around the grid are
always spelled from left to right (i.e. all the digits are always on the same
row), in order to extract a number we can simply scan linearly until we stop
seeing digits. Let's write a function to extract a number in this way: it will
take the row and the starting column as parameters and return a number converted
to integer. For simplicity, we'll also pass the row length since we have it at
hand. The .isdigit()
method of strings comes in handy
(technically, .isdigit()
doesn't only accept ASCII digits, but
we know our entire input is ASCII, so it's fine).
def extract_number(row, start, length):
end = start + 1
while end < length and row[end].isdigit():
end += 1
return int(row[start:end])
Now all we need is a way to determine whether a number is adjacent to any symbol. There are a multitude of ways to go about this, with small variations that can make the code look completely different. I chose the one that seemed more readable and intuitive to me.
Following a similar approach as the one used to extract a number, we can scan linearly from left to right starting from one column before the number and stopping one column after the number, checking the row containing the number as well as the one above and below. If we find any symbol, we can stop checking and we know that number needs to be accounted for.
Let's write another function to do this:
- Start from the column before the first digit and keep going until we have a digit on the given row.
- Check the same column in the row above and below the given row: if there's a symbols, stop here.
- Check the column of the given row, if there's a symbol, stop here. Otherwise stop unconditionally if there's no digit and we are past the starting column.
Point 3 above is a bit tricky, but is basically what allows us to stop scanning when we encounter the end of the current number.
We don't exactly know what symbols there might be, but we know that anything
that is not a dot (.
) or a digit is a symbol, so that's the check we'll
implement.
Here's the function:
def has_adjacent_symbols(grid, r, start_c, height, width):
# Avoid going out of bounds on the left if start_c is 0
start_c = max(0, start_c - 1)
for c in range(start_c, width):
# Check above
if r > 0 and grid[r - 1][c] not in '0123456789.':
return True
# Check below
if r < height - 1 and grid[r + 1][c] not in '0123456789.':
return True
# Check given row
if not grid[r][c].isdigit():
# Symbol
if grid[r][c] != '.':
return True
# No more digits, stop
if c > start_c:
break
return False
Since the row we are given (r
) is always the same, and indexing lists is a
pretty slow operation in Python, we can simplify the above function while also
making it more readable by extracting the three rows (given row, row above and
row below) right away. A conditional expression comes in handy
for this purpose.
def has_adjacent_symbols(grid, r, start_c, h, w):
row = grid[r]
# Take row above if possible, else an empty list
above = grid[r - 1] if r > 0 else []
# Take row below if possible, else an empty list
below = grid[r + 1] if r < height - 1 else []
for c in range(max(0, start_c - 1), width):
if above and above[c] not in '0123456789.':
return True
if below and below[c] not in '0123456789.':
return True
if not row[c].isdigit():
if row[c] != '.':
return True
# No more digits, stop
if c > start_c:
break
return False
The multiple not in ...
checks may seem a bit redundant, but there aren't many
alternatives that are shorter or simpler to read.
We now have all we need to solve the problem. All that's left to do is iterate over the grid one row at a time, and scan characters in the row until we find a digit. Once we do, check for adjacent symbols and extract the number if any symbol is found.
The enumerate()
built-in is nice to have to iterate
both on rows and their row index at once.
answer = 0
for r, row in enumerate(grid):
c = 0
while c < width:
# Skip all non-digits
while c < width and not row[c].isdigit():
c += 1
# Stop if we are past the end of the row
if c >= width:
break
# We have a digit, check for adjacent symbols and extract it
if has_adjacent_symbols(grid, r, c, height, width):
answer += extract_number(row, c, width)
# Skip remaining digits
while c < width and row[c].isdigit():
c += 1
print('Part 1:', answer)
The task now becomes a bit more complex: we need to identify "gears", that is:
all the asterisk (*
) symbols that have exactly two adjacent numbers. For each
such pair of numbers, we then need to calculate the product, and sum all the
products up.
It may seem like a lot of work, but the way we wrote the
has_adjacent_symbols()
function for part one makes it easy to modify it to
identify gears instead of just stopping at the first symbol and returning a
boolean. We can transform the function to return a boolean plus a list of
coordinates for asterisk symbols found. We can then use those coordinates (row
and column index) as an unique identifier of a given asterisk: each time we
encounter the same coordinates it means we found the same one.
Let's turn has_adjacent_symbols()
into adjacent_symbols()
:
def adjacent_symbols(grid, r, start_c, height, width):
row = grid[r]
# Take row above if possible, else an empty list
above = grid[r - 1] if r > 0 else []
# Take row below if possible, else an empty list
below = grid[r + 1] if r < height - 1 else []
# List of coordinates of adjacent '*' symbols found
gears = []
# True if any adjacent symbol is found
adj_to_symbol = False
for c in range(max(0, start_c - 1), width):
if above and above[c] not in '0123456789.':
adj_to_symbol = True
if above[c] == '*':
gears.append((r - 1, c))
if below and below[c] not in '0123456789.':
adj_to_symbol = True
if below[c] == '*':
gears.append((r + 1, c))
if not row[c].isdigit():
adj_to_symbol |= row[c] != '.'
if row[c] == '*':
gears.append((r, c))
# No more digits, stop
if c > start_c:
break
return adj_to_symbol, gears
Now a call to adjacent_symbols()
returns a boolean indicating whether any
symbols were found adjacent to the number starting at start_c
in the row at
index r
, as well as the coordinates of all the *
symbols encountered. We can
keep track of the numbers adjacent to a given *
symbol with a dictionary of
lists. Using a defaultdict(list)
makes it even
easier as we can just .append()
without worrying if a given
symbol was already added to the dictionary or not.
from collections import defaultdict
# For each '*' symbol, keep a list holding its adjacent numbers
gears = defaultdict(list)
The main loop needs minimal modifications:
answer1 = 0
for r, row in enumerate(grid):
c = 0
while c < width:
# Skip all non-digits
while c < width and not row[c].isdigit():
c += 1
# Stop if we are past the end of the row
if c >= width:
break
adj_to_symbol, adj_gears = list(adjacent_symbols(grid, r, c, height, width))
if adj_to_symbol:
number = extract_number(row, c, width)
answer1 += number
# For each '*' symbol, add the current number to the list of numbers
# adjacent to the symbol
for coords in adj_gears:
gears[coords].append(number)
# Skip remaining digits
while c < width and row[c].isdigit():
c += 1
print('Part 1:', answer1)
We now have a complete gears
dictionary of the form {(r, c): [num, ...]}
,
for example:
{
(1, 2): [123, 456],
(3, 4): [999],
# ...
}
We already have the part 1 answer. For part 2 we are asked to consider *
symbols that have exactly two adjacent numbers, so we can simply iterate over
gears
and check which lists have a length of 2
. Simple enough!
answer2 = 0
for numbers in gears.values():
if len(numbers) == 2:
answer2 += numbers[0] * numbers[1]
We can also use math.prod()
to calculate the product:
from math import prod
for numbers in adj_numbers.values():
if len(numbers) == 2:
answer2 += prod(numbers)
Since all we are doing is summing up, we can simplify even more with the help of a few builtins:
filter()
each list with alambda
function that checks the length, to only extract the ones with length2
.map()
each list to the product of its numbers withprod()
.sum()
up all the products.
len_2_lists = filter(lambda x: len(x) == 2, gears.values())
products = map(prod, len_2_lists)
answer2 = sum(products)
Or more concisely:
answer2 = sum(map(prod, filter(lambda x: len(x) == 2, gears.values()))))
print('Part 2:', answer2)
Although I like the conciseness of the above expression, one may prefer the more verbose loop since it is undoubtedly easier to understand if you are not an hardcore Python dev. Well, a bit of golfing is always fun anyway, so I'll keep it as is. Six stars out of fifty!
Problem statement — Complete solution — Back to top
For each line of input we have two easy to parse lists of integers, and we want to know how many integers of the first list are also in the second one. We can do this with a simple loop while parsing the input.
To extract the lists we can start by discarding anythig before the first :
,
whose index can be found with .find()
. Then,
.split()
the line on |
(the delimiter of the two lists),
split again each list on whitespace, and finally map()
everything to int
.
fin = open(...)
for line in fin:
winners, numbers = line[line.find(':') + 1:].split('|')
winners = list(map(int, winners.split()))
numbers = list(map(int, numbers.split()))
Now we have two lists for each input line that are easy to work with: to check
how many elements of winners
are in numbers
the first thing that comes to
mind is simply iterating over the former and checking if elements are in the
latter using the in
operator:
matches = 0
for w in winners:
if w in numbers:
matches += 1
This isn't really optimal though, as the in
operator has to scan numbers
every single time. It would be better if numbers
was a set
, so that
membership could be tested instantly. In fact... if both winners
and numbers
are set
s, we can use the &
(binary AND) operator to calculate the
intersection of the two sets. The length of the
intersection will then be equal to the number of elements that are in both sets,
which is what we want.
for line in fin:
winners, numbers = line[line.find(':') + 1:].split('|')
winners = set(map(int, winners.split()))
numbers = set(map(int, numbers.split()))
matches = len(numbers & winners)
Now, as the rules of the game suggest, for each card, each matching number
doubles the score of the card, which starts at 1
on the first match. This is
basically just a power of two, so we can calculate it with one expression
without the need of loops. Doing 2**(matches - 1)
seems enough, but in case of
zero matches we will get 2**-1 == 0.5
while we would want to have 0
instead.
We can transform the result to int
to work around this.
score = 0
for line in fin:
winners, numbers = line[line.find(':') + 1:].split('|')
winners = set(map(int, winners.split()))
numbers = set(map(int, numbers.split()))
matches = len(numbers & winners)
score += int(2**(matches - 1))
print('Part 1:', score)
We don't need to calculate a score anymore, but instead for each card we need to duplicate the N cards past the current one, where N is the number of matches of the current card. So for example, if card 1 has 3 matches, we'll need to duplicate card 2, 3 and 4. After doing this for all cards, we want to know how many cards we end up with in total (originals and copies).
The peculiar thing is that we need to account for the copies as well, so each time we process a card, every single copy of that card will generate a new copy of the next N cards. Therefore, in general, the number of matches must be multiplied by the number of copies of the card we have.
Let's keep track of the number of matches calculated in part 1 with a list
:
+matches = []
+
for line in fin:
winners, numbers = line[line.find(':') + 1:].split('|')
winners = set(map(int, winners.split()))
numbers = set(map(int, numbers.split()))
- matches = len(numbers & winners)
- score += int(2**(matches - 1))
+ matches.append(len(numbers & winners))
The total score for part 1 could still be calculated inside the loop, or outside
the loop with a sum()
and a
generator expression:
score = sum(int(2**(n - 1)) for n in matches)
print('Part 1:', score)
To keep track of the number of copies of each card, we can either use a dict
or a list
. Since we already know that we have exactly len(matches)
cards, we
can just use a list
. Initially, we have 1
copy of each card:
copies = [1] * len(matches)
For each card i
we process, we need to add one copy of the next matches[i]
cards. Since the card itself could have already been copied, instead of adding
only one copy of the next cards, we'll add copies[i]
instead. We can
enumerate()
the matches to have both the current
card's index and its number of matches ready. Other than that, it's just a
matter of a couple of for
loops:
for i, n in enumerate(matches):
for j in range(i + 1, i + n + 1):
copies[j] += copies[i]
The answer we are after is the total number of cards (copies included), so we
can just sum()
up the number of copies:
total = sum(copies)
print('Part 2:', total)
Problem statement — Complete solution — Back to top
Interesting problem today! And unfortunately also some kind of annoying input parsing... so let's get it out of the way quickly.
Our input is split in "sections" that are delimited by empty lines. The first section is simply a list of seeds (integers), while the other sections represent "mappings". Each mapping has a header (first line) followed by a list of entries that are 3-tuples of integers, one per line. With enough splitting and mapping, we should be able to do it:
.split()
the whole input in sections: simple enough, just need to split on two consecutive newlines ('\n\n'
).- Parse the seeds, simply discarding the initial
'seeds:'
,.split()
on whitespace andmap()
the numbers toint
. - Parse the mappings one at a time: for each one, discard the header and then
build a list of 3-tuples, again splitting on whitespace and converting to
int
.
This time I decided to write a function for input parsing, here it is:
def parse_input(fin):
sections = fin.read().split('\n\n')
seeds = sections[0]
seeds = list(map(int, seeds[6:].split()))
mappings = []
for section in sections[1:]:
mapping = []
mappings.append(mapping)
for line in section.splitlines()[1:]:
entry = tuple(map(int, line.split()))
mapping.append(entry)
return seeds, mappings
The entries in each mapping are given in the form dst src length
meaning that
numbers in the source range [src, src + length)
should map to the destination
range [dst, dst + length)
. This simply means that if a number is in the source
range, a delta needs to be applied to bring it to the destination range: that
delta is simply dst - src
. Since it's easier to work with source range and
delta, let's just convert each entry in the mapping accordingly right at the
moment of parsing:
def parse_input(fin):
# ... unchanged ...
for line in section.splitlines()[1:]:
- entry = tuple(map(int, line.split()))
- mapping.append(entry)
+ dst, src, length = map(int, line.split())
+ mapping.append((src, src + length, dst - src))
# ... unchanged ...
The input can now be easily parsed:
fin = open(...)
seeds, mappings = parse_input(fin)
We now need to apply all the mappings (in order) to each of the seeds
we have,
and find the smallest final value (after all mappings are applied). We can do
this quite easily:
- For each seed value
v
, iterate over all mappings one by one. - For each entry
start, end, delta
in the mapping: ifv
is in the range[start, end)
, thenv + delta
is the new value and we can stop scanning the entries of this mapping and proceed to the next one. Otherwise,v
remains unchanged.
We can use float('inf')
for an initial "infinite" positive value as minimum.
Putting the above into code we already have a complete solution:
minimum = float('inf')
for v in seeds:
for step in steps:
for start, end, delta in step:
if start <= v < end:
v += delta
break
minimum = min(minimum, v)
print('Part 1:', minimum)
The seed
numbers we were initially given now need to be interpreted
differently: they are pairs of the form start_seed n
. Each pair represents the
seeds in the range [start_value, start_value + n)
. The request is unchanged,
but now we have a whole lot more seeds to work with.
If we take a look at our input, we can see numbers that are easily in the hundreds of millions. Needless to say, a bruteforce approach (simply checking all seeds in each range), is going to be way too slow. We could maybe get away with it a compiled programming language like C/C++/Go/Rust, but (1) we are stuck with Python 3 and (2) that's just lame, we want to find the real solution!
The way to go about this is to write a new algorithm that is able to work with
segments instead of single values. If we think about it, for every input
segment [A, B)
, for every source range [C, D)
of a given mapping, we can
only have four different possibilities of overlap:
(1) Complete (2) Partial inner
AxxxB A----xxx----B
C----xxx----D CxxxD
(3) Partial right (4) Partial left
A----xxxB Axxx----B
Cxxx----D C----xxxD
While processing a segment, for each segment (source range) in the mapping, if
we have overlap, we can "extract" the overlapping and non-overlapping parts: the
overlapping part can be converted to the destination range by applying delta
on both its ends, while the non-overlapping parts can be kept to check if it
overlaps with other segments.
The first thing to do now is to convert the initial list of seeds
into a list
of segments. We can do this easily if we iterate its elements pairwise:
segments = []
for i in range(0, len(seeds), 2):
start_value, n = seeds[i:i + 2]
segments.append((start_value, start_value + n))
An alternative way to do this is through the use of zip()
:
segments = []
for start_value, n in zip(seeds[::2], seeds[1::2]):
segments.append((start_value, start_value + n))
The "trick" here is that seeds[::2]
returns a list of the elements with even
indices, and seeds[1::2]
returns a list of the elements with odd indices.
Using a generator expression we can also get rid of the loop:
segments = [(v, v + n) for v, n in zip(seeds[::2], seeds[1::2])]
The algorithm to implement is now as follows:
- Start with a list of segments to convert.
- For each mapping:
- Create an empty list of processed segments (to be converted by the next mapping).
- For each segment we need to convert:
- Check for overlaps with each segment of the mapping:
- In case of overlap, extract the overlapping part, shift it by
delta
and move it to the list of processed segments. Then extract the non-overlapping part(s) and add them back to the list of segments to convert (they may overlap with other segments in this mapping). - In case of no overlap with any segment of the mapping, move the segment as is to the list of processed segments.
- In case of overlap, extract the overlapping part, shift it by
- Check for overlaps with each segment of the mapping:
- Take the list of processed segments as the new list of segments to convert.
Let's write a function to implement the above. In order to perform fast removal
and insertion from the beginning, we can use a deque
for both the list of segments to convert and the list of processed segments.
from collections import deque
def solve(segments, mappings):
for mapping in mappings:
processed = deque()
while segments:
a, b = segments.popleft()
for c, d, delta in mapping:
partial_left = c <= a < d
partial_right = c < b <= d
if partial_left and partial_right:
# Complete overlap:
# a---b
# c-----------d
# Entire [a, b) segment is converted
processed.append((a + delta, b + delta))
break
if partial_left:
# Partial left overlap:
# a------b
# c------d
# [a, d) is converted
processed.append((a + delta, d + delta))
# [d, b) may overlap with other segments, keep it
segments.append((d, b))
break
if partial_right:
# Partial right overlap:
# a------b
# c------d
# [a, d) is converted
processed.append((c + delta, b + delta))
# [a, c) may overlap with other segments, keep it
segments.append((a, c))
break
if a < c and b > d:
# Partial inner overlap:
# a-----------b
# c---d
# [c, d) is converted
processed.append((c + delta, d + delta))
# [a, c) may overlap with other segments, keep it
segments.append((a, c))
# [d, b) may overlap with other segments, keep it
segments.append((d, b))
break
else:
# No overlap with any segment in this mapping, keep as is
processed.append((a, b))
segments = processed
# We are done, the minimum possible final value will be the minimum start
# of the final segments after conversion
return min(s[0] for s in segments)
A careful reader may have noticed the for: ... else:
construct above: the
else:
branch is only entered if no break
happens inside the for
. Since we
break on any overlap, the else:
is only entered when there is no overlap with
any segment of the mapping.
Quite nice! All that's left to do is call the above function and print the result!
segments = deque(segments)
minimum = solve(segments, mappings)
print('Part 2:', minimum)
If we really want, we could even use the function we just wrote to solve part 1:
after all, single numbers can be seen as segments of length 1
:
segments = deque((s, s + 1) for s in seeds)
minimum = solve(segments, mappings)
print('Part 1:', minimum)
As an interesting plus, my solution for today's part 2 is 84206669
, which
contains the numbers 420, 666 and 69, LOL. 10 stars and counting!
Problem statement — Complete solution — Back to top
Simple and fast problem today!
First of all, input parsing: we just have two lists of integers, so a simple
.split()
plus map()
will do after
splitting the input in lines and discarding the words at the start of each line.
fin = open(...)
lines = fin.readlines()
times = map(int, lines[0][9:].split())
dists = map(int, lines[1][9:].split())
We don't really need to make times
and dists
into lists or tuples as we will
iterate over them just once.
Now, without thinking too much about it, it seems simple enough to just run a
couple loops to solve the problem. For each pair of time t
and distance d
,
we can simply try and hold for any possible time between 1
and t
, and see if
that wins the race. The remaining time we have to run will be t
minus the time
we hold. Adding up all the times this happens gives us the number of different
ways to win the race, which we can accumulate as a product as requested.
To iterate over pairs of times
and dists
we can use
zip()
.
result = 1
for t, d in zip(times, dists):
wins = 0
for hold in range(1, t):
if hold * (t - hold) > d:
wins += 1
result *= wins
print('Part 1:', result)
That was... easy! Wasn't it?
Now we are told that we need to discard all whitespace between input times and distances, and only consider the single resulting big value for time and distance. Needless to say, these values are pretty huge. Our part 1 solution is still feasible, but would definitely benefit from some optimization.
If you think about it, the problem is quite simple, and we can express it with a single mathematical formula. We have t time to win the race, and to win we need to travel at least d distance. The time we have to run is equal to the total time of the race (t) minus the time x we hold at the start. Therefore we win if:
Where
Given the quadratic equation
$x_\text{min} = \frac{-t + \sqrt{t^2 - 4d}}{-2} = \frac{t - \sqrt{t^2 - 4d}}{2}$ $x_\text{max} = \frac{-t - \sqrt{t^2 - 4d}}{-2} = \frac{t + \sqrt{t^2 - 4d}}{2}$
Since we want a non-negative result we then know that any value of
Let's write a function to calculate the solution and directly give us the
answer. Since we are dealing with powers and square roots, the numbers we'll
calculate will be float
, but we can use math.floor()
and
math.ceil()
as needed to get integral values. Ideally we'd
want to calculate
from math import ceil, floor
def solve(t, d):
delta = (t**2 - 4*d)**0.5
xmin, xmax = (t - delta) / 2, (t + delta) / 2
return ceil(xmin) - floor(xmax) - 1
We can now simplify the loop we wrote in part 1:
answer = 1
for t, d in zip(times, dists):
wins = solve(t, d)
answer *= wins
And since all we are doing is calculating a product, we can use
math.prod()
with a generator expression to get
rid of the loop entirely:
from math import prod
answer = prod(solve(t, d) for t, d in zip(times, dists))
We can also use map()
and get rid of the generator expression (it will unpack
pairs of values from times
and dists
and pass them as two arguments for us):
answer = prod(map(solve, times, dists))
print('Part 1:', answer)
Sweet! For part 2, all we need to do is remove whitespace between the input
numbers using .replace()
and do a single calculation:
time = int(lines[0][9:].replace(' ', ''))
dist = int(lines[1][9:].replace(' ', ''))
answer = solve(time, dist)
print('Part 2:', answer)
Don't you love it when a single closed formula can solve the entire problem? I definitely do.
Problem statement — Complete solution — Back to top
Input parsing seems quite easy today: we have a list of pairs where the first
one is a string and the second an integer. We can simply
.split()
each input line, convert to int
the second element
and be done with it. We are asked to somehow sort the hands in ascending order
according to their strength, so we need a way to remember the bet associated
with each hand: we can do this with a dict
of the form {hand: bet_value}
.
Let's build it:
fin = open(...)
bets = {}
for line in fin:
hand, bet = line.split()
bets[hand] = int(bet)
Now, the strength of a hand can be calculated according to a few simple rules,
which basically only boil down to counting the number of occurrences of each
card in the hand to establish the type of the hand. For hands of the same type,
the stronger hand is the one with the highest card from left to right. For
example, given AAAKK
and AKAKA
(both full houses of aces over kings), the
first one is stronger since the second card is A
, while the second card of the
second hand is K
, which is lower than A
.
The cards we are given may be digits between 2
and 9
as well as any of the
letters TJQKA
. Given two hands that have the same counts of cards (i.e. none
is immediately higher than the other), it would be nice to split ties by simply
comparing the two hands as strings (e.g. hand_a < hand_b
).
This would work very well for ASCII digits since '0' < '1'
,
'1' < '2'
and so on until '8' < '9'
, but it does not work as for the letters
we have since e.g. 'A' < 'K'
, while we would want A
to have a higher value.
To overcome this limitation and have easy comparisons, we can simply choose
other characters for cards higher than 9
. Instead of TJQKA
, we can use
ABCDE
, since 'A' > '9'
, 'B' > 'A'
and so on.
Let's translate all the letters in the hands we have from TJQKA
to ABCDE
: we
can do this easily with str.translate()
after building a
translation table with str.maketrans()
.
tbl = str.maketrans('TJQKA', 'ABCDE')
bets = {}
for line in fin:
hand, bet = line.split()
bets[hand.translate(tbl)] = int(bet)
The above loop be simplified with the help of map()
:
for hand, bet in map(str.split, fin):
bets[hand.translate(tbl)] = int(bet)
And since all we are doing is creating a dictionary we could also use a dict comprehension.
bets = {hand.translate(tbl): int(bet) for hand, bet in map(str.split, fin)}
Okay... that might be a bit hard to read for the average Python programmer, but it does look pretty cool. I'll keep it in my solution's code.
Nice! Now the keys of the bets
are strings representing hands, and have been
translated to be easily comparable. In case of hands of the same type (e.g. two
double pairs), we can use a normal Python string comparison to pick the highest.
How can we determine the kind of a hand? Well, by counting the number of equal
cards, of course. The Counter
class from the
collections
module can do this for us easily. Once we know
the counts, it is pretty simple to establish the hand type. Let's see what
happens for each type:
Type | hand |
Counter(hand) |
Sorted frequencies |
---|---|---|---|
Five of a kind | 'AAAAA' |
{'A': 5} |
[5] |
Four of a kind | 'AAAAB' |
{'A': 4, 'B': 1} |
[4, 1] |
Full house | 'AAABB' |
{'A': 3, 'B': 2} |
[3, 2] |
Three of a kind | 'AAABC' |
{'A': 3, 'B': 1, 'C': 1} |
[3, 1, 1] |
Double pair | 'AABBC' |
{'A': 2, 'B': 2, 'C': 1} |
[2, 2, 1] |
One pair | 'AABBC' |
{'A': 2, 'B': 1, 'C': 1, 'D': 1} |
[2, 1, 1, 1] |
High card | 'ABCDE' |
{'A': 1, 'B': 1, 'C': 1, 'D': 1, 'E': 1} |
[1, 1, 1, 1, 1] |
From the above table, it should seem quite obvious that the only information we
need to establish the strength of a hand based on its type is the frequencies of
its cards. Given two cards, to know which one has the stronger type, we can
simply compare the counter frequencies in descending order! The first hand that
has a higher frequency wins. In fact, we have [5] > [4, 1]
, [4, 1] > [3, 2]
,
[3, 2] > [3, 1, 1]
and so on.
Let's write a function to calculate the strength of a given hand so that we can
later pass it as a key=
function to sorted()
. We can
first determine the strength of the hand type using the card frequencies
returned by Counter(hand)
, then, in case of tie (same type), we can look at
the cards in the hand itself.
All we need to be able to sort a collection of hands is a tuple of the form
(descending_frequencies, hand)
, where descending_frequencies
is a tuple or
list of int
, and hand
is the string representing the hand itself. The
sorted()
function will then first compare the frequencies and in case of tie
compare the hands. Given that both compare exactly the way we want with simple
Python comparison operators, this is all that's needed!
from collections import Counter
def strength(hand):
counter = Counter(hand).values()
frequencies = sorted(counter, reverse=True)
return (frequencies, hand)
We can now sort the keys of our bets
dictionary using the above function as
key=
and calculate the total as defined by the problem statement: sum the
products between the rank of each hand and its bid value. This can be easily
done with the help of enumerate()
providing start=1
.
ordered_hands = sorted(bets, key=strength)
total = 0
for rank, hand in enumerate(ordered_hands, 1):
total += rank * bets[hand]
To simplify things, we can map()
the hands to their bet
through bet.get()
, and then use sum()
plus a
generator expression to compute the total:
ordered_hands = sorted(bets, key=strength)
ordered_bets = map(bets.get, ordered_hands)
total = sum(rank * bet for rank, bet in enumerate(ordered_bets, 1))
We can also calculate the total
with map()
and
math.prod()
. In any case, we already have our answer!
total = sum(map(prod, enumerate(ordered_bets, 1)))
print('Part 1:', total)
We are now supposed to treat the J
cards as "jokers". A Joker can assume the
value of any other card, and when used in a hand, it will assume the value of
the card that maximizes the hand's strength. When breaking ties though, a J
should be considered weaker than a 2
(in other words, by itself, it is the
weakest card).
Since the Joker should be considered the weakest card when breaking ties, we
cannot use the previous translation of TJQKA
to ABCDE
anymore, but we can
use A0CDE
instead, since '0' < '2'
. Let's first translate the cards again.
We simply need to turn any B
into a 0
:
new_bets = {}
for hand, bet in bets.items():
new_bets[hand.replace('B', '0')] = bet
The above can also be simplified using a dict comprehension:
new_bets = {hand.replace('B', '0'): bet for hand, bet in bets.items()}
The initial request is the same: sort the hands in ascending order based on their strength, then calculate the total winnings. However, calculating the strength of a given hand became a bit harder, as we need to account for jokers too.
Let's see what happens in case we have jokers in different scenarios:
-
Five of a kind:
JJJJJ
, a five of a kind of jokers is simply the weakest five of a kind. -
Four of a kind: we can either have
XXXXJ
orJJJJX
. In both cases, converting the jokers to the other card will give us a five of a kind. Converting any of the jokers to any other card will just weaken the hand. -
Full house: we can either have
XXXJJ
orJJJXX
. In both cases, converting the jokers to the other card will give us a five of a kind. Again, converting any of the jokers to any other card will just weaken the hand. -
Three of a kind: we can either have
XXXYJ
orJJJXY
. The best we can get is a four of a kind: convertJ
toX
in the first case or toY
in the second case. -
Double pair: we can either have
XXYYJ
orJJXXY
. The best we can get is a full house convertingJ
toX
. -
One pair: we can either have
XXYZJ
orJJXYZ
. The best we can get is a three of a kind convertingJ
toX
in the first case or to any of the other cards in the second case. -
High card:
XYZWJ
... we can simply get a pair convertingJ
to any of the other cards.
It should be clear enough from the above examples: whenever we have one or more jokers available, the best thing to do is to make them all count as the card with the highest frequency. Doing anything else can only turn the hand into a weaker type, or at best keep it of the same type.
Following this logic, let's write another function to deal with hands that may
contain jokers (remember that we previously translated jokers to 0
to easily
compare hands of the same type).
def strength_with_joker(hand):
# Calculate frequencies for each card
counter = Counter(hand)
# Rremove jokers
jokers = counter.pop('0', 0)
# Sort frequencies in descending order
freqs = sorted(counts.values(), reverse=True)
# Convert all jokers to the card with the highest frequency
freqs[0] += jokers
return freqs, hand
The only edge case to consider is a five of a kind of jokers: in that case after
counter.pop('0', 0)
we'll be left with nothing, so frequs[0]
will fail. We
can simply detect this at the start:
def strength_with_joker(hand):
if hand == '00000':
return [5], hand
counter = Counter(hand)
jokers = counter.pop('0', 0)
freqs = sorted(counts.values(), reverse=True)
freqs[0] += jokers
return freqs, hand
The final result can be calculated in exactly the same as we did for part 1:
ordered = map(new_bets.get, sorted(bets, key=strength_with_joker))
total = sum(map(prod, enumerate(ordered, 1)))
print('Part 2:', total)
I really enjoyed this one! 14/50 stars.
Problem statement — Complete solution — Back to top
First problem of the year on a graph??? Hmm, almost.
Parsing the input is rather straightforward. First, let's extract the first line
of input representing the directions to follow, removing the final newline with
.rstrip()
:
fin = open(...)
directions = fin.readline().rstrip()
To make it easier to work with, let's also convert the list of directions in
the directions to integers: 0
for L
and 1
for R
. We'll see why this is useful
in a moment.
new_directions = []
for direction in directions:
if direction == 'R':
new_directions.append(1)
else:
new_directions.append(0)
directions = new_directions
The above can also be simplified by converting the boolean comparison result to
an integer since int(True) == 1
and int(False) == 0
:
new_directions = []
for direction in directions:
new_directions.append(int(direction == 'R'))
directions = new_directions
Let's just go one step further and convert the above to a generator expression:
directions = tuple(int(d == 'R') for d in directions)
Next, let's extract the network nodes from each of the following input lines
with simple slicing operations, since every line has exactly the
same format and node names are all 3 characters long. To represent the graph of
nodes we'll use a dict
of the form {source: (left, right)}
.
# Skip empty line
fin.readline()
g = {}
for line in fin:
# Lines are of the form: 'AAA = (BBB, CCC)\n'
src, left, right = line[:3], line[7:10], line[12:15]
g[src] = (left, right)
That could also be simplified to a single generator expression, or more precisely a dict comprehension (ok, I might be getting too comfortable with these, but whatever):
fin.readline()
g = {l[:3]: (l[7:10], l[12:15]) for l in fin}
Now we have all we need to solve the problem. Counting the steps from AAA
to
ZZZ
is merely a matter of following the directions through our graph g
.
Given the way we built g
, going from one node to the next can be done with
next_node = g[node][direction]
. All we need to do is follow the directions
,
and since we may need to repeat them multiple times, we can use
itertools.cycle()
to make our life easier:
from itertools import cycle
node = 'AAA'
steps = 0
for d in cycle(directions):
node = g[node][d]
steps += 1
if node == 'ZZZ':
break
We can also use enumerate()
starting from 1
to count
the steps. Either way, we're done for part 1.
node = 'AAA'
for d in enumerate(cycle(directions), 1):
node = g[node][d]
if node == 'ZZZ':
break
print('Part 1:', steps)
Now things get significantly more complex, and to be fair probably a tad too
much for my liking. We are told to start from all nodes ending with A
and
simultaneously follow the directions like we did for part 1 until all the nodes
we are on end with Z
.
This may seem like a simple problem at first sight, but as it turns out it's all except simple. Attempting the naïve solution (literally advancing N nodes in parallel) will take us nowhere since the number of steps required is way too large (for my input, it was in the order of 1013).
Why exactly am I saying that this is hard? Well, because the key to simplifying this problem is detecting cycles of nodes, but:
- It's hard to determine when a cycle is encountered.
- We can encounter "temporary" cycles that only loop a few times before never being seen again.
- Once in a cycle, we could encounter more than one Z-ending node per loop.
Firstly, we cannot simply tell that we are in a cycle if we reach the same node twice. Simple directions and graphs can give us false cycles where we encounter the same node multiple times before wandering off. For example, take the following input (I use shorter node names for simplicity):
LLLLRRLLRRRLL
A = (Z, A)
Z = (A, A)
Starting from A
, the nodes we'd visit would be: A, Z, A, Z, A, A, A, Z, A, A,
A, A, Z. We clearly encountered Z
multiple times: after 2 steps, after 4
steps, after 8 and after 13. It's only after this that we'll keep visiting the
same pattern again, so we have a cycle of length 13 where a node ending with Z
is encountered four times per cycle at offsets 2, 4, 8 and 13.
This is clearer if we use pairs of the form (node, i)
where i
is the index
in the list of directions to follow. We have: (A 0), (Z 1), (A 2), (Z 3), (A 4),
(A 5), (A 6), (Z 7), (A 8), (A 9), (A 10), (A 11), (Z 12). After this, we have
exhausted the list of directions (of length 13), so we go back to: (A 0), (Z 1),
(A 2), and so on.
In other words, a real loop is one where we encounter the same at the same
step in the given list of directions. Therefore, a real loop must have a
length that is a multiple of the length of the given directions
(len(directions)
).
Moreover, a situation like the following would also be completely valid:
L
A = (B, X)
B = (C, X)
C = (D, X)
D = (Z, X)
Z = (D, X)
Here the nodes we'd visit would be A, B, C, D, Z, D, Z, D, Z [...]. In this case, we initially have a "useless" chain (A, B, C) leading us into the real cycle (D, Z), so we also have an "initial offset" to account for.
This leads us into a situation similar (albeit more complex) to the one of 2020 day 13 part 2, where the problem was solvable in a purely mathematical way using the Chinese Remainder Theorem.
The problem statement does not help much, because the generic description we are given does not exclude the above possibilities. In other words, we aren't directly given any hint about special properties of the input that would simplify things for us. However, as it turns out, there definitely are special properties that hold for all inputs of today's problem and that reduce the complexity by an order of magnitude.
This isn't really the first time this has happened, and it's common for AoC puzzles. Nonetheless it left a lot of people (including me) wondering why their very simple solution, which was making a lot of assumptions about the input, would work. Here's a thread on the AoC subreddit where a lot of people people shared their ideas.
Given the above (and most importantly given the fact that I really don't have the time nor the willpower to solve the general problem), I will focus on explaining the properties of the input that make this problem simpler to solve and then verify them in my solution.
Let's use "A-nodes" to refer to the nodes ending with A
, and "Z-nodes" to
refer to nodes ending with Z
. A simpler version of the problem would include
the following constraints:
- Following the given directions, each A-node only reaches one Z-node;
- Continuing to follow the directions from such Z-node, it is guaranteed that we will loop back to it without encountering any other Z-node;
- The length of the loop is equal to the number of steps required to reach the Z-node from the A-node.
If the above assumptions hold, this means that for each A-node we have exactly one reachable Z-node, and we will encounter such Z-node once every N steps, where N is equal to the length of the loop and also equal to the number of initial steps from the A-node to the Z-node.
As we already saw, "looping back" does not simply mean encountering the same node twice, but encountering it twice exactly after a number of steps that is a multiple of the length of the given directions.
Let's write a function that given a starting A-node node finds and returns the length of the loop to the first Z-node.
def steps(g, node, directions):
directions_iter = enumerate(cycle(directions), 1)
for n1, d in directions_iter:
node = g[node][d]
if node[-1] == 'Z':
# Z-node encountered for the first time, stop here
break
# Remember this Z-node
znode = node
# Continue following the path until we find another Z-node
for loop_len, (n2, d) in enumerate(directions_iter, 1):
node = g[node][d]
if node[-1] == 'Z':
# Second Z-node found (should be the same as the first)
break
# Check assumptions:
#
# 1) Each A-node should only reach one Z-node;
# 2) Continuing from such node, we should loop back to it without
# encountering any other Z-nodes;
# 3) The length of the loop is equal to the number of steps required to
# reach the Z-node from the A-node.
#
assert node == znode
assert n1 % len(directions) == n2 % len(directions)
assert n1 == loop_len
return loop_len
Since we are assuming that the loop length is equal to the number of steps
needed to reach the Z-node in the first place, we can also use this function to
solve part 1. In fact, we are guaranteed to be able to reach ZZZ
from AAA
by
the problem statement, and since we are also verifying that only one Z-node is
reachable from any A-node, the ZZZ
node is also guaranteed to be the only node
reachable from AAA
.
Let's find the loop length for each A-node. First we need to find the A-nodes:
this can be done with a simple scan of the keys of our g
graph dictionary
using filter()
plus a lambda
or a
filtering generator expression:
a_nodes = filter(lambda node: node[-1] == 'A', g)
# Or, equivalent:
a_nodes = (node for node in g if node[-1] == 'A')
For each starting A-node, we can calculate the loop length using the function we
just wrote. For part 1 we can simply pass 'AAA'
.
steps_part1 = steps(g, 'AAA', directions)
print('Part 1:', steps_part1)
For any other A-node we can use map()
or a generator expression:
cycle_lengths = map(lambda a: steps(g, a, directions), a_nodes)
We are now essentially dealing with N cycles of different lengths: how can we
determine the number of steps needed for them to sync up? Pretty easy: just
calculate the least common multiple (LCM) of the lengths using the
math.lcm()
function (available since Python 3.9). Since this
function takes an arbitrary number of arguments (but not an iterable), we can
unpack the iterable when passing it as argument.
steps_part2 = lcm(*cycle_lengths)
print('Part 2:', steps_part2)
Quite the conundrum today, can't say I enjoyed jumping through hoops and verifying questionable input assumptions, but hey, that's what we get I guess!
Problem statement — Complete solution — Back to top
We are dealing with lists of integers. Parsing the input couldn't be easier. For
each line, we can simply .split()
and
map()
to int
, a patter we should be familiar with by now.
fin = open(...)
for line in fin:
numbers = list(map(int, line.split()))
# ...
We are told to keep calculating pairwise differences, so let's write a generator
function that does exactly this. There are different ways to do it, I will use
iter()
to turn the list into an iterator, then extract the
first number, and iterate over the next ones, yelding the differences one at a
time.
def deltas(nums):
it = iter(nums)
prev = next(it)
for n in it:
yield n - prev
prev = n
Sweet. Now, according to the rules, we can simply keep calling deltas(nums)
until all the deltas we calculate are 0
. For example:
0 3 6 9 12 15
3 3 3 3 3
0 0 0 0
Now we are told to append one more 0
to the last list (of all zeroes), and,
going backwards, find the next element of the original list. So, for example:
0 3 6 9 12 15 A --> 0 3 6 9 12 15 [18]
3 3 3 3 3 B --> 3 3 3 3 3 [3]
0 0 0 0 [0] --> 0 0 0 0 [0]
The number A above is what we are looking for. Since we added a 0
at the end
of the last list, we know that B - 3 = 0, and therefore that B = 0 + 3 = 3.
Similarly, A - 15 = B, therefore A = B + 15 = 0 + 3 + 15 = 18.
It's easy to see how A is nothing more than the sum of the last elements of each of the list we have.
Let's write a function that calculates it for us. We'll keep computing
differences until we have all zeroes, and accumulate the sum of the last
(rightmost) element each time. To check if all the elements of a list are 0
,
we can either write a good ol' loop, or we can use the all()
builtin: all(x == 0 for x in nums)
will tell us if there is any number that is
not 0
in the list.
def solve(nums):
tot = 0
while 1:
tot += nums[-1]
nums = list(deltas(nums))
if all(x == 0 for x in nums):
break
return tot
Alternatively, we can invert the check using any()
to verify
that our lis contains any non-zero value, and use it as the loop condition:
def solve(nums):
tot = 0
- while 1:
+ while any(nums):
tot += nums[-1]
nums = list(deltas(nums))
-
- if all(x == 0 for x in nums):
- break
return tot
For each input list, we now need to calculate the next number according to the rules using the above function, then sum all of themse numbers up.
total = 0
for line in fin:
numbers = list(map(int, line.split()))
total += solve(numbers)
print('Part 1:', numbers)
We now need to do a similar thing to find the previous number of each list.
Such number is found by first prepending a 0
to the last list (of all zeroes)
and working backwards.
For example:
A 10 13 16 21 30 45 --> [5] 10 13 16 21 30 45
B 3 3 5 9 15 --> [5] 3 3 5 9 15
C 0 2 4 6 --> [-2] 0 2 4 6
D 2 2 2 --> [2] 2 2 2
[0] 0 0 --> [0] 0 0
This time, we know that:
- 2 - D = 0
- 0 - C = D
- 3 - B = C
- 10 - A = B
Therefore we have:
- A = -B + 10
- = -(-C + 3) + 10
- = -(-(-D + 0) + 3) + 10
- = -(-(-(-0 + 2) + 0) + 3) + 10
- = -(-0 + 2) + 0 - 3 + 10
- = 0 - 2 + 0 - 3 + 10
- = 5
It's a little bit harder to see this time, but what we are doing now is simply
calculating the sum of the leftmost numbers of each list, inverting the sign of
the ones in odd positions. Indeed, reordering the second to last steop above, we
have 10 - 3 + 0 - 2 + 0
.
We can therefore calculate this number as easily as how we did for part 1:
simply keep adding the first number of each list, multiplied by either 1
or
-1
depending on the iteration.
def solve(nums):
tot_right = tot_left = 0
sign = 1
while any(nums):
tot_right += nums[-1]
tot_left += sign * nums[0]
sign = -sign
nums = list(deltas(nums))
return tot_right, tot_left
As simple as that, we now have a function that returns both the previous number and the next number of the original list, so we can calculate the answers for both part 1 and 2 at the same time:
total1 = total2 = 0
for line in fin:
nums = list(map(int, line.split()))
r, l = solve(nums)
total1 += l
total2 += r
print('Part 1:', total1)
print('Part 2:', total2)
Problem statement — Complete solution — Back to top
We are going off on grid problems this year it seems. This time the grid is
relatively simple, and we won't even need it for much else than extracting some
coordinates. In fact, the only thing we care about is the position of each
galaxy (#
) in the grid. As we need to find rows or columns that are empty
(composed only of empty space .
and no galaxy #
).
Let's take a top-down approach this time and leave the input parsing for later.
The problem statements asks us to alculate the taxicab distance
between all possible pairs of galaxies in the grid. The taxicab distance between
two points
Here's an example 1D input (consisting of a single row):
.##.#..#.#
Now, given that we are told that each space (.
) expands to two spaces, and
therefore each empty column in reality corresponds to two empty columns, we
cannot simply calculate the distance right away with a handful of subtractions.
We first need to figure out the real coordinates of each galaxy after the
expansion. This can be done by scanning the only row we have from left to right
and keeping track of the real column index: each time a space (.
) is crossed,
increment the index by two, while each time a galaxy (#
) is encountered,
remember its index and then increment by one.
Let's write a generator function for this:
def expand(row):
column = 0
for char in row:
if char == '#':
yield column
column += 1
else:
# Empty space expands: each '.' becomes two '.'
column += 2
Simple enough. However, we are dealing with a grid of galaxies, so while it is true that the above works for the simple 1D case, in the more general 2D case there can be more than one galaxy on each column, for example:
.##.#..#.#
##...#.##.
In order to account for this, we can build a list of counts for each column.
Let's do this with our actual input. The parsing couldn't be easier, just
.read()
everyting and use str.splitlines()
to get a
list of strings. Then, scan it and keep track of the count of galaxies on each
column. While we are at it, we can also do the same for rows, since the
situation is analogous. As usual with grids, the
enumerate()
will help iterating over both the
row/column indices and the characters at the same time.
with open(...) as fin:
grid = fin.read().splitlines()
row_counts = [0] * len(grid)
col_counts = [0] * len(grid[0])
for r, row in enumerate(grid):
for c, char in enumerate(row):
if char == '#':
row_counts[r] += 1
col_counts[c] += 1
Now that we have our counts we can easily identify empty rows/columns: those
will be the ones where the count is zero. Let's modify the expand()
function
we wrote earlier to take the counts into account. It's simple enough, for each
zero count we have a space, and for each non-zero count we will just need to
yield
the same index as many times as the count.
def expand(counts):
real_index = 0
for count in counts:
if count > 0:
# One or more galaxies, yield all of them
for _ in range(count):
yield real_index
real_index += 1
else:
# Empty space expands: each '.' becomes two '.'
real_index += 2
Now that we are able to calculate the expanded indices for both rows and columns, we need to calculate the sum of all pairwise distances. How can we do this? The simplest solution would be to iterate over all possible pairs of indices using two nested loops, but we can do better.
Let's take for example the following case:
.##.#
##...
12101 <-- counts per column
As we can see, the above gives us the column counts [1, 2, 1, 0, 1]
, which
translate to the expanded column indices [0, 1, 1, 2, 5]
(the double 1
is
because of the two galaxies on the same column).
How do we calculate the sum of all the pairwise distances (which in this case is nothing else than the sum of pairwise differences)? Simply iterate over all possible pairs. In order to avoid counting pairs twice, we can write one loop iterating from the firt to the last, and a second inner loop iterating from the first to the current position of the outer loop:
indices = [0, 1, 1, 2, 5]
total = 0
for i in range(len(indices)):
for j in range(0, i):
total += indices[i] - indices[j]
What we just calculated is:
$(1 - 0)$ $+ (1 - 1) + (1 - 0)$ $+ (2 - 1) + (2 - 1) + (2 - 0)$ $+ (5 - 2) + (5 - 1) + (5 - 1) + (5 - 0)$
If we take a closer look, we can see that the above simplifies to:
$1 - 0$ $+ 2\cdot1 - (1 + 0)$ $+ 3\cdot2 - (1 + 1 + 0)$ $+ 4\cdot5 - (2 + 1 + 1 + 0)$
In general, for a given element
Which is equal to:
The first term is the i
-th element multiplied by the number of elements
preceding it, and the second term is the sum of all the elements preceding it.
The above formula allows us to calculate the sum of pairwise differences in a single linear sweep of the values from lowest to highest. All we have to do is keep track of the index, and a partial sum of preceding values. Let's implement it!
def sum_pairwise_distances(values):
total = partial_sum = 0
for i, x in enumerate(values):
total += i * x - partial_sum
partial_sum += x
return total
As simple as that. Now, all we need to do to solve the problem is first
expand()
the row and column coordinates given the row and column counts, then
calculate the sum of pairwise distances for the rows and for the columns. Let's
write a functio to do it:
def solve(row_counts, col_counts):
return sum_pairwise_distances(expand(row_counts)) + \
sum_pairwise_distances(expand(col_counts))
And the solution is one function call away!
total = solve(row_counts, col_counts)
print('Part 1:', total)
The problem remains the same, but this time the expansion factor is 1000000,
meaning that each space .
turns into 1000000 spaces. Thankfully, we already
have everything we need to solve the problem. We are already accounting for
expansion in the expand()
function, we can simply add an additional
multiplier
to make it more generic:
-def expand(counts):
+def expand(counts, multiplier):
real_index = 0
for count in counts:
if count > 0:
for _ in range(count):
yield real_index
real_index += 1
else:
- real_index += 2
+ real_index += multiplier
The solve()
function can be modified to take a multiplier
parameter as well,
and pass it down:
def solve(row_counts, col_counts, multiplier=2):
return sum_pairwise_distances(expand(row_counts, multiplier)) + \
sum_pairwise_distances(expand(col_counts, multiplier))
And we can now solve part 2:
total2 = solve(row_counts, col_counts, 1000000)
print('Part 2:', total2)
The algorithm we wrote is already fast as is for a moderate number of galaxies
like the one we are dealing with, but it can technically be made faster by
getting rid of the expand()
function and incorporating the calculations into
the sum_pairwise_distances()
. This is what I did for my final solution.
The key takeaway is that instead of yielding the same
- The current real value
$x_i$ taking into account the expansion multiplier (the same way asexpand()
does); - The partial sum of the previous values, which can be expressed as the sum of
each previous value multiplied by its count:
$S_i = \sum\limits_{j=0}^{i-1} n_j x_j$ ; - The number of previous values, which can be expressed as a partial sum of the
counts:
$N_i = \sum\limits_{j=0}^{i-1} n_j$ .
The total then grows of space
corresponds to previous
corresponds to partial_sum
corresponds to
Doing this, we get the following simplified function that computes the result in significantly less iterations:
def sum_distances(counts, multiplier):
total = partial_sum = previous = space = 0
for n in counts:
if n:
total += n * (previous * space - partial_sum)
partial_sum += n * space
previous += n
space += 1
else:
space += multiplier
return total
Problem statement — Complete solution — Back to top
To parse the input data, we can simply .split()
each line,
keep the first string as is, then split the second string on commas (,
) and
map()
every number to int
.
fin = open(...)
for line in fin:
record, groups = line.split()
groups = tuple(map(int, groups.split(',')))
It's clear that we are dealing with repeated occurrences of the same basic problem, given one per line. There is a possibility to brute force each one, which incidentally is what I did for part 1 in my original solution in a pretty inconvenient way, but that seems like a pretty slow solution... let's think about it.
As it turns out, brute force might not be the answer, but it sure does get very close to it. In fact, let's start implementing a brute-force approach, and see what can be done to improve it.
For each record
, the goal is to try and build the groups of #
in all
possible ways that match the requested lengths (groups
). Given an input record
and the list of group lengths we are looking for, we can iteratively build the
groups one at a time, scanning the record while keeping track of the remaining
characters to check, the remaining groups we still have to complete, and the
length of the current group we are building.
The logic for each step can then be implemented as follows:
- If we encounter a
#
we can simply add1
to the current group length. - If we encounter a
.
we can be in two situations:- We were already building a group, so the current group length is positive. In such case, we can check the first of the remaining group lengths to build and verify that we correctly reached exactly that length. If so, we can continue, otherwise we made a mistake and can throw away this attempt.
- We were not building a group, so the current group length is zero. In such
case, we can just keep going forward to find the next
#
or?
where we can start building the next group.
- If we encounter a
?
we can choose to turn it either in a#
or a.
, so we simply try both, applying the logic of both the above steps.
When do we stop? Well, of course when we get to the end of the record. Once we do, there are two possibilities: either we built all groups correctly (so we can count this arrangement in the total) or we did not (so we don't count this arrangement).
If we built all groups correctly, we can be in one of two scenarios (either one is good):
- No groups left to build and we are currently not building one (current group length is zero).
- Exactly one group left to build and we just completed it (the current group length is the same as the requested one).
Now that we have all the logic, we can implement it as a recursive function:
def solve(record: str, groups: tuple[int], curlen=0) -> int:
# Did we reach the end of the record?
if not record:
if len(groups) == 0 and curlen == 0:
# We have no groups left to build and we aren't currently building
# one (we did not encounter any '#' after building the last group)
return 1
if len(groups) == 1 and curlen == groups[0]:
# We only have one group left to build and the current group we are
# building has exactly the requested length
return 1
# We either have more than one group left to build or we are building
# more groups than needed
return 0
char = record[0]
total = 0
if char == '#':
# (1.) Current group gets one more '#'
total += solve(record[1:], groups, curlen + 1)
elif char == '.':
if curlen == 0:
# (2.1) Not currently building a group and cannot start here, keep going
total += solve(record[1:], groups, 0)
elif len(groups) > 0 and curlen == groups[0]:
# (2.2) We completed this group correctly, advance to the next one
total += solve(record[1:], groups[1:], 0)
elif char == '?':
# (1.) Try turning this into a '#'
# Current group gets one more '#'
total += solve(record[1:], groups, curlen + 1)
# (2.) Try turning this into a '.'
if curlen == 0:
# (2.1) Not currently building a group and cannot start here, keep going
total += solve(record[1:], groups, 0)
elif len(groups) > 0 and curlen == groups[0]:
# (2.2) We completed this group correctly, advance to the next one
total += solve(record[1:], groups[1:], 0)
return total
The advancement to the next record character is done by simply popping one from
the front of the string (record[1:]
), and the same goes for the group lengths
(groups[1:]
).
The above function should already solve the problem, but before using it we can
simplify it a little bit. Notice how in the case of ?
we are repeating the
logic we use for #
and .
(makes sense since we can turn ?
into any of #
or .
). We can simply move the check for ?
in the two if
statements above
and get rid of the duplicated logic:
def solve(record: str, groups: tuple[int], curlen=0) -> int:
# ... unchanged ...
- if char == '#':
+ if char == '#' or char == '?':
# (1.) Current group gets one more '#'
total += solve(record[1:], groups, curlen + 1)
- elif char == '.':
+ if char == '.' or char == '?':
if curlen == 0:
# (2.1) Not currently building a group and cannot start here, keep going
total += solve(record[1:], groups, 0)
elif len(groups) > 0 and curlen == groups[0]:
# (2.2) We completed this group correctly, advance to the next one
total += solve(record[1:], groups[1:], 0)
- elif char == '?':
- # (1.) Try turning this into a '#'
- # Current group gets one more '#'
- total += solve(record[1:], groups, curlen + 1)
-
- # (2.) Try turning this into a '.'
- if curlen == 0:
- # (2.1) Not currently building a group and cannot start here, keep going
- total += solve(record[1:], groups, 0)
- elif len(groups) > 0 and curlen == groups[0]:
- # (2.2) We completed this group correctly, advance to the next one
- total += solve(record[1:], groups[1:], 0)
return total
The solution we just built looks like brute force, but has a couple of smart
checks that make it avoid bad arrangements early on. For example, when
encountering a .
, we only advance if we aren't building a group
(curlen == 0
) of if we matched the current requested group length. Going
forward unconditionally would just mean reaching a bad final arrangement. This
already cuts down the search space a little bit.
We can implement one additional check to make it a little bit smarter: if we
ever find ourselves in a situation where the current group length is higher than
the next requested group length, or if we finish building the requested groups
but encounter any #
, we can simply stop. This will prune away even more bad
arrangements early on:
def solve(record: str, groups: tuple[int], curlen=0) -> int:
# ... unchanged ...
if len(groups) > 0 and curlen > groups[0]:
# Building a group that is too long
return 0
if len(groups) == 0 and curlen > 0:
# Building a group when we shouldn't be building any more
return 0
char = record[0]
total = 0
# ... unchanged ...
We can now run the solve()
function for each line of input and accumulate the
total to get the answer we are looking for:
total = 0
for line in fin:
records, groups = line.split()
groups = tuple(map(int, groups.split(',')))
total += solve(records, groups)
print('Part 1:')
The problem remains the same, but each record now needs to be quintupled in
length: the string representing the record is repeated five times and four ?
are added in between. The group lengths also need to be extended by simply
copying them five times.
Okay, we already have a solution, let's run it. We can use simple multiplication
to repeat the groups
, while for the records we can use
str.join()
to add ?
characters between five copies of the
original string.
total1 = total2 = 0
for line in fin:
records, groups = line.split()
groups = tuple(map(int, groups.split(',')))
total1 += solve(records, groups)
records = '?'.join([record] * 5)
groups = groups * 5
total2 += solve(records, groups)
print('Part 1:', total1)
print('Part 2:', total2)
Well, that couldn't have been easier! Shortest part two of the year??? Well no, not really. We could stare at the terminal for ages, but the above script will not terminate anytime soon. Indeed, while our "smart" brute force solution works fine for small inputs, the problem complexity grows exponentially with the input length, and there now are simply too many possible arrangements to test extensively!
As it turns out though, we are very close to the solution. In fact, we are so close that all we need to add is two lines of code... but let's see why first.
To drastically improve the current solution we need to realize one thing: if we ever encounter the same situation more than once, we are doing unneeded work, as we already previously computed the answer.
As an example, let's check the following input:
record: ??...??...### | groups: 1,1,3 | curlen: 0
We have two ways to compose the first group:
#....??...###
.#...??...###
In both cases, we continue advancing and reach the same scenario:
record: ..??...### | groups: 1,3 | curlen: 0
record: ..??...### | groups: 1,3 | curlen: 0
...??...### 1,3
This now means that we will try to solve the same problem twice. Needless to say, solving the same subproblem multiple times is clearly not optimal. If we can figure out a way to remember partial solutions we can check each time if a solution already exists and avoid any further (useless) computation.
As it turns out, this can be easily achieved using a dictionary as a cache! The
dictionary keys will be tuples of the form (record, groups, curlen)
and the
dictionary values will be the solution for that specific input. At the start,
before doing any calculation, we can check the dictionary to see if we already
have a solution, and if so return it right away. At the end of any calculation,
before returning the solution we'll store the calculated total
in the
dictionary so that it can be used later.
The changes to our solve()
function are minimal:
CACHE = {}
def solve(record: str, groups: tuple[int], curlen=0):
key = (record, groups, curlen)
if key in CACHE:
# Result was already calculated
return CACHE[key]
# ... actual computation here ...
# Save result to reuse it later
CACHE[key] = total
return total
Congrats, we just solved a problem using dynamic programming! Nothing fancy, we just needed to add some memoization.
What I just wrote above can be achieved with a single line of code (in case all
the function arguments are hashable, which is our case) using the
functools.cache
(since Python 3.9) or
functools.lru_cache
decorators.
from functools import cache
@cache
def solve(record, groups, curlen=0):
# ... unchanged ...
The loop we wrote earlier now simply works as is without any change. However,
both the @cache
and @lru_cache
decorators provide a way to clear the cache
of the function, which can be done through func.cache_clear()
. We don't
necessarily need to do it, but since technically each line of input is a
different problem, we can clear the cache each time to avoid wasting memory
caching unneeded solutions from previous problems.
total1 = total2 = 0
for line in fin:
records, groups = line.split()
groups = tuple(map(int, groups.split(',')))
total1 += solve(records, groups)
records = '?'.join([record] * 5)
groups = groups * 5
total2 += solve(records, groups)
print('Part 1:', total1)
print('Part 2:', total2)
24/50 stars and counting!
Problem statement — Complete solution — Back to top
Seems like we are dealing with reflections today. The input is a series of
rectangular character grids separated by empty lines. It should be simple to
parse it and transform each grid into a list of strings (rows): just a couple of
split operations will do, and str.splitlines()
comes in
handy.
fin = open(...)
raw_grids = fin.read().split('\n\n')
for raw_grid in raw_grids
grid = raw_grid.splitlines()
Since all we are doing is calling .splitlines()
for each grid, we can use
map()
instead:
for grid in map(str.splitlines, raw_grids):
# ... solution code here ...
OK... now, what do we do with those? Well, we find reflections of course! The problem statement seemed a bit hard to understand today, but in short, all we need to do for each grid is find the horizontal line at which we get a reflection of all the rows either from the bottom or the top (whichever is closer), and the same goes for vertical lines.
As an example, take the following grids:
#....#..# ###.##.#.#
..##..### .#..... ##..###.#. v.#..##.###v
v#####.##.v ####..# .#....#.#. ^.#..##.###^
^#####.##.^ v# # ##.v .#.#.##..# ###.##.#.#
..##..### ^# # ##.^ v# ## #.##.v ...###.#.#
#....#..# ####..# ^# ## #.##.^ #.#.#.#.#.
The horizontal reflection lines for these grids are in between the rows marked
with v
and ^
. Not all the rows must be reflected, in fact, it is enough for
only one set of rows from either the top or the bottom to be fully reflected,
and the others can be ignored. In fact, the only "perfect" reflection happening
above is in the first grid. The same reasoning applies to vertical reflections.
For all the grids we have we need to find reflections and count the rows above (for horizontal reflections) or to the left (for vertical reflections). How can we do this? The grids are small enough that there is no need to come up with insanely optimized solutions: a simple scan through the grid choosing all possible sizes for the reflection will do just fine.
Let's write a function to solve this for horizontal reflections, since it's
easier to work with rows (they are simple strings) than columns. The maximum
size of a horizontal reflection (expressed in number of rows) is h // 2
, where
h
is the height of the grid. For each possible reflection size from 1
to
h // 2
, we can either have a reflection at the top or at the bottom of the
grid. It will be enough to make a couple of slice operations and comparisons to
detect it.
Once we have two chunks of rows obtained through slicing (e.g. grid[a:b]
and
grid[c:d]
), to check whether they are a reflection of each other we can simply
reverse one of the two ([::-1]
) and compare it with the other. Since comparing
two lists means comparing their contents, we'll get a match only if both lists
have the same length and contain equal strings.
def find_reflections(grid):
height = len(grid)
for size in range(1, height // 2 + 1):
# Check for reflection of the top size rows
top = grid[:size]
bottom = grid[size:2 * size]
if top == bottom[::-1]:
return size
# Check for reflection of the bottom size rows
top = grid[height - 2 * size:height - size]
bottom = grid[height - size:]
if top == bottom[::-1]:
return height - size
# No horizontal reflection found
return 0
Technically, doing bottom = grid[size:2 * size]
and then bottom[::-1]
immediately after performs a copy twice. We could simplify this to a single
slice doing grid[2 * size - 1:size - 1:-1]
, though it's a bit hard to read.
The same goes for the second bottom slice operation which can be expressed as
grid[height - 1:height - size - 1:-1]
. I always found 3-argument slices pretty
weird and counter-intuitive. In any case, uNfortunately we are always performing
a copy when slicing... that's a shortcoming of the language that does not allow
to perform non-copying slices, but again since grids are small this is no big
deal.
What about vertical reflections? Columns seem much more annoying to deal with. A clever way to get around the problem is transposing the grid and checking for horizontal reflections again. This isn't in general the best solution, but for small lists like the ones we are dealing with it's no big deal. Let's write a function to transpose a grid (list of strings):
def transpose(grid):
newgrid = []
for c in range(len(grid[0])):
row = []
for r in range(len(grid)):
row.append(grid[r][c])
newgrid.append(''.join(row))
return newgrid
The above can be simplified quite a lot, first by replacing the inner loop with a generator expression, since all we are doing is building a list:
def transpose(grid):
newgrid = []
for c in range(len(grid[0])):
row = [grid[r][c] for r in range(len(grid))]
newgrid.append(''.join(row))
return newgrid
We are now iterating over all the rows of the grid at the same time, each
iteration extracting the c
-th element of each of them. This is exactly what
the zip()
already does...
def transpose(grid):
newgrid = []
for row in zip(*grid):
newgrid.append(''.join(row))
return newgrid
And finally, we can do away with the loop by using map()
to
join the rows into strings:
def transpose(grid):
return list(map(''.join, zip(*grid)))
Okay, it took more time to write a simple grid transposition than solving the actual problem :') but we are done. Now it's only a matter of applying the same function to all the grids in our input. Continuing from the initial input parsing code we wrote, we now have:
total = 0
for grid in map(str.splitlines, raw_grids):
total += 100 * find_reflections(grid)
total += find_reflections(transpose(grid))
print('Part 1:', total)
Onto part 2!
We are now told that in addition to the reflections we just found, there also are additional reflections happening, where all cells except one match. This means that looking for reflections, we will not get a match by simply comparing the two sides like we are doing now. For example:
v###.#.##.v
^#####.##.^
The minimal grid shown above has a horizontal reflection where all cells match
except one (the first .
on the first row). If we turn that .
into a #
we
get a real reflection.
We could try all possible combinations by flipping each cell once and checking for reflections, but that'd be awfully slow (and also awful in general). What we can do instead is count the number of differences after extracting the two sides of the reflection. If there are zero different cells we have a reflection that is good for part 1, otherwise if there is exactly one we have a reflection that is good for part 2.
Let's write a function to count the number of differences given two lists of
strings. We will assume that the second list has already been flipped over so
that we can easily iterate over both lists using zip()
:
def count_differences(a, b):
diff = 0
for linea, lineb in zip(a, b):
for chara, charb in zip(linea, lineb):
if chara != charb:
diff += 1
return diff
We could theoretically stop counting whenever we reach a diff
higher than 1
since we won't consider that a reflection in any case. Furthermore, the inner
loop can be simplified with a sum()
plus a generator
expression:
def count_differences(a, b):
diff = 0
for linea, lineb in zip(a, b):
diff += sum(chara != charb for chara, charb in zip(linea, lineb))
if diff > 1:
break
return diff
Good. We can now modify the find_reflections()
function that we wrote for part
1 to use count_differences()
and return both lines of reflections.
def find_reflections(grid):
height = len(grid)
perfect = imperfect = 0
for size in range(1, height // 2 + 1):
# Check for reflection of the top size rows
top = grid[:size]
bottom = grid[size:2 * size]
diff = count_differences(top, bottom[::-1])
if diff == 0:
perfect = size
elif diff == 1:
imperfect = size
# Check for reflection of the bottom size rows
top = grid[height - 2 * size:height - size]
bottom = grid[height - size:]
diff = count_differences(top, bottom[::-1])
if diff == 0:
perfect = height - size
elif diff == 1:
imperfect = height - size
return perfect, imperfect
We can also stop as soon as we find both the perfect and the imperfect reflection:
def find_reflections(grid):
height = len(grid)
perfect = imperfect = 0
for size in range(1, height // 2 + 1):
# Check for reflection of the top size rows
top = grid[:size]
bottom = grid[size:2 * size]
diff = count_differences(top, bottom[::-1])
if diff == 0:
perfect = size
elif diff == 1:
imperfect = size
+
+ if perfect and imperfect:
+ break
# Check for reflection of the bottom size rows
top = grid[height - 2 * size:height - size]
bottom = grid[height - size:]
diff = count_differences(top, bottom[::-1])
if diff == 0:
perfect = height - size
elif diff == 1:
imperfect = height - size
+
+ if perfect and imperfect:
+ break
return perfect, imperfect
And finally, we can integrate both part 1 and part 2 calculations in the same main loop:
total1 = total2 = 0
for grid in map(str.splitlines, grids):
perfect, imperfect = find_reflections(grid)
total1 += 100 * perfect
total2 += 100 * imperfect
perfect, imperfect = find_reflections(transpose(grid))
total1 += perfect
total2 += imperfect
print('Part 1:', total1)
print('Part 2:', total2)
Problem statement — Complete solution — Back to top
Another day dealing with rectangular character grids. This time we need to transform the grid we are given though, so that's going to be fun.
Parsing the input is a walk in the park, just read everything and split the lines into a list, and while we are at it also calculate the grid height and width for later:
with open(...) as fin:
grid = fin.read().splitlines()
height, width = len(grid), len(grid[0])
We have two main kinds of cells in our grid: cubic rocks #
, which stay fixed
in place, and spherical rocks O
, which we need to move around. It seems easier
to represent things using sparse matrices instead of a
single conventional matrix (the grid
we just created above).
Let's create two sparse matrices, one for fixed rocks and one for movable rocks.
We'll use sets of coordinates as sparse matrices since we don't really need to
associate coordinates with any kind of value. As usual,
enumerate()
comes in handy to easily iterate over both
cells and grid coordinates.
fixed = set()
movable = set()
for r, row in enumerate(grid):
for c, char in enumerate(row):
if char == '#':
fixed.add((r, c))
elif char == 'O':
movable.add((r, c))
We now need to move all the movable rocks (O
) up, sort of like swiping UP in
the game 2048. We can do this by creating a new sparse matrix and
iterating over the coordinates in movable
in ascending order (since the rocks
that need to be moved first are the ones at the top). We can use the
sorted()
built-in to do this.
For each movable rock we have, starting from its original position, move one
cell up at a time, and stop when the top is reached (r < 0
) or at the first
rock encountered, then add the new coordinates to the new movable matrix. We can
either encounter rocks that are fixed or that are in the new matrix.
def move(fixed, movable):
new_movable = set()
for r, c in sorted(movable):
# Go up one cell at a time until we encounter another (fixed or movable) rock
r -= 1
while r >= 0 and (r, c) not in fixed and (r, c) not in new_movable:
r -= 1
# Given that the above condition is checked only after decrementing r,
# we will always be one step further up than needed, so use r + 1 here
new_movable.add((r + 1, c))
return new_movable
Since we only care about the ordering of the r
coordinate, we can use a key=
function to extract only that coordinate and speed the sorting up a little bit.
The operator.itemgetter()
function seems a bit
cleaner than writing a lambda
function in this case:
- for r, c in sorted(movable):
+ for r, c in sorted(movable, key=itemgetter(0)):
Secondly, we are doing two membership checks at a time: in fixed
and
in new_movable
. We can simplify things by starting with
new_movable = fixed.copy()
and checking only for in new_movable
. At the end
of the function we can then remove all the fixed coordinates at once with a
set difference. This may seem slower, but in this case the
operations of copying a set and computing a set difference, which are
implemented in native code (here for the
curious), are way faster than all the double membership checks we are doing in
Python code.
def move(fixed, movable):
- new_movable = set()
+ new_movable = fixed.copy()
for r, c in sorted(movable, key=itemgetter(0)):
r -= 1
- while r >= 0 and (r, c) not in fixed and (r, c) not in new_movable:
+ while r >= 0 and (r, c) not in new_movable:
r -= 1
new_movable.add((r + 1, c))
- return new_movable
+ return new_movable - fixed
We almost have all we need. After applying the move()
function above to our
set of movable rocks, we then need to compute the sum of their distance from the
bottom of the grid:
new_movable = move(fixed, movable)
total = 0
for r, _ in new_movable:
total += height - r
Since all we are doing is summing things up, we can simplify the above using the
classic map()
+ sum()
combo with a
lambda
:
new_grid = move(fixed, movable)
total = sum(map(lambda rc: height - rc[0], new_grid))
print('Part 1:', total)
We now need to move the movable rocks in all four directions: first up, then left, then down, and finally right. These 4 consecutive move operations are called a "spin", and we need to perform 1000000000 (one billion!) spins before calculating the final sum.
First of all: we know how to move rocks up, as we just wrote a function to do
it, but what about down, left, and right? Well, we don't necessarily need to
complicate things by writing a new function or changing the current one.
Instead, we can simply rotate our coordinates and apply the same move()
function each time.
Rotating the grid 90 degrees clockwise and moving everything up is equivalent to moving everything left, the only difference is that we then need to rotate it back 90 degrees counter-clockwise after the movement. However, since for each "spin" we need to move things 4 times (once per direction), we can simply rotate 90 degrees clockwise each time, and end up in the same direction as we started.
Writing a rotation function is straightforward: row indices become column indices, and column indices become row indices (but from the bottom, not the top, so we need to invert them). Let's write a function to rotate our sparse matrices:
def rotate_90(coords, height):
rotated = set()
for r, c in coords:
rotated.add((c, height - r - 1))
return rotated
We can also simplify the above down to a single generator expression:
def rotate_90(coords, height):
return set((c, height - r - 1) for r, c in coords)
Okay, now to the real problem: needless to say, we won't get anywhere with a literal implementation that spins one billion times. That's the equivalent of 4 billion rotations and 4 billion movements. It's going to take ages. We could maybe get away with it in a compiled language such as C/C++/Rust/Go, but nonetheless, there must be a smarter solution!
What if, after some arbitrary number of spins, we find ourselves in the same configuration as the start? That would mean that the possible configurations are only a finite amount, and that we can only cycle through them. Once a cycle is found, a bit of math is all that's needed to determine the final result.
Let's write a function that keeps applying spins until a cycle is identified. In
order to check if we already saw a configuration, and at which iteration, we can
use a dictionary of the form {seen_set: iteration_number}
. As we are working
with sets, and sets are not hashable (as they are mutable), we will need to turn
any set into an immutable frozenset
first before adding it as
key in a dictionary.
def spin(fixed, movable, height, width):
seen = {frozenset(movable): 0}
for i in range(1, 1000000000 + 1):
# Move up, rotate
movable = rotate_90(move(fixed, movable), height)
fixed = rotate_90(fixed, height)
# Move left, rotate
movable = rotate_90(move(fixed, movable), width)
fixed = rotate_90(fixed, width)
# Move down, rotate
movable = rotate_90(move(fixed, movable), height)
fixed = rotate_90(fixed, height)
# Move right, rotate
movable = rotate_90(move(fixed, movable), width)
fixed = rotate_90(fixed, width)
# Freeze the set and check if we already saw it
key = frozenset(movable)
if key in seen:
break
# Didn't see this set yet, take note of the current iteration
seen[key] = i
# ...
The above function can be simplified a bit, but we first need to figure out how to compute the final result once we identify a cycle. We can sketch the situation up with some ASCII-art:
Iteration 0 1 2 3 4 5
Configuration A --> B --> C --> D --> E --> F
^ |
| |
+-----------------+
In the above example, at iteration 6 we find ourselves with the same
configuration (C
) as iteration 2. At this point, we are back at the start of
the cycle, and if we keep going we'll see D
, E
... and so on.
To find the iteration at which the cycle starts, we can simply check the values
in seen
. The cycle length is then simply the current iteration minus the start
iteration (in the above example it's 4). The number of iterations remaining is
one billion minus the cycle start iteration. Once we have these values, we can
calculate the final step (inside the cycle) with a modulo operation.
Let's complete the function with this in mind. We'll return the final solution right away.
def spin(fixed, movable, height, width):
seen = {frozenset(movable): 0}
for i in range(1, 1000000000 + 1):
for h in (height, width, height, width):
movable = rotate_90(move(fixed, movable), h)
fixed = rotate_90(fixed, h)
key = frozenset(movable)
if key in seen:
break
seen[key] = i
cycle_start = seen[key]
cycle_len = i - cycle_start
remaining = iterations - cycle_start
final = remaining % cycle_len + cycle_start
# Find the configuration corresponding to the final iteration
for key, i in seen.items():
if i == final:
break
# Calculate final result as we did for part 1
return sum(map(lambda rc: height - rc[0], key))
There are a couple of optimizations to apply here:
- As fixed rocks don't move, we will only ever have 4 different
fixed
configurations possible (one per 90-degree rotation), therefore we can pre-calculate them. - Instead of iterating over
seen
to find the configuration corresponding to thefinal
iteration, we can [ab]use it to also store the inverse mapping{i: key}
. After all, there is no risk of confusing things as one is an integer and the other is afrozenset
, and storing a reference is a lightweight operation that does not perform a copy.
def spin(fixed, movable, height, width):
seen = {frozenset(movable): 0}
+
+ cache = [(fixed, height)]
+ for _ in range(3):
+ fixed = rotate_90(fixed, height)
+ cache.append((fixed, width))
+ height, width = width, height
for i in range(1, 1000000000 + 1):
- for h in (height, width, height, width):
- movable = rotate_90(move(fixed, movable), h)
- fixed = rotate_90(fixed, h)
+ for fixed, height in cache:
+ movable = rotate_90(move(fixed, movable), height)
key = frozenset(movable)
if key in seen:
break
seen[key] = i
+ seen[i] = key
cycle_start = seen[key]
cycle_len = i - cycle_start
remaining = iterations - cycle_start
final = remaining % cycle_len + cycle_start
- # Find the configuration corresponding to the final iteration
- for key, i in seen.items():
- if i == final:
- break
+ key = seen[final]
# Calculate final result as we did for part 1
return sum(map(lambda rc: height - rc[0], key))
We can now take our shiny spin()
function for a spin and solve part 2:
total = spin(fixed, movable, height, width)
print('Part 2:', total)
Problem statement — Complete solution — Back to top
Today we need to implement a (rather simple) custom
hash function. The problem asks us to work with ASCII values
of the input strings, so it seems appropriate to read the input in binary mode
('b'
) and work with bytes
instead of str
, since iterating over bytes
objects yields integers corresponding to the byte values, and in case of ASCII
characters those are exactly the ASCII values we want.
To parse the input we only need to .strip()
away unwanted
newlines and .split()
on commas to obtain a list of bytes
objects:
with open(..., 'rb') as fin:
strings = fin.read().strip().split(b',')
Now we can write a function to calculate the hash of an input string. The rules are very simple, so it's only a few lines of code:
def aoc_hash(s):
h = 0
for c in s:
h = ((h + c) * 17) % 256
return h
And finally, map()
each string to its hash and
sum()
up all the hashes:
total = sum(map(aoc_hash, strings))
print('Part 1:', total)
That was fast!
We now need to implement a hashmap with some... weird (to say the least) rules. Each string in the input can be in one of two forms:
key-
meaning "removekey
".key=v
meaning "set the value associated withkey
tov
".
The hashmap has 256 slots numbered from 0
to 255
that are lists. Each slot
contains entries of the form (key, value)
. The hashmap behaves as follows:
- When a key is removed from the, the corresponding
(key, value)
entry is removed from the slot corresponding to the key hash (obtained using the hashing function we just wrote for part 1). - When a key is added, it is inserted at the end of the slot corresponding to
the key hash. However, in case the key was already present in such slot, the
position of its corresponding entry remains unchanged, but the value is
replaced with the new one, so an entry
(key, oldval)
becomes(key, newval)
.
Without much thinking, we can quite literally implement the above instructions
as is. Let's create a HashMap
class just for the sake of it. We will use a
defaultdict
as the backing store of slots, just
to avoid having to deal with the insertion of empty lists when a given slot is
accessed the first time.
from collections import defaultdict
class HashMap:
def __init__(self):
self.backing_store = defaultdict(list)
In both the removal and the insertion case, the old entry needs to be removed
from the slot. The only difference is that in the insertion case, a new entry
is also inserted in the slot at the same index as the existing one (if any).
Let's implement an internal method to find the entry associated with a key in a
given slot and .pop()
it away. We'll return the index in case
the entry was found and popped, or -1
in case no entry was found.
It'd be nice to simply call slot.index(key)
, but the slots
contain entries of the form (key, v)
, so we cannot simply look for a key
, so
we'll need to iterate them manually. The enumerate()
built-in is always helpful:
class HashMap:
# ... continued from above ...
def _find_and_pop(self, slot, key):
for i, (k, _) in enumerate(slot):
if k == key:
slot.pop(i)
return i
return -1
Awesome, now we need a .remove(key)
method and an .insert(key, value)
method. In case of removal we just need to find the slot and use
self._find_and_pop()
, while in case of insertion we also need to .insert()
the new (key, value)
entry either at the same index the old entry was popped
or at the end of the slot.
Here's the code:
class HashMap:
# ... continued from above ...
def remove(self, key):
h = aoc_hash(key)
slot = self.backing_store[h]
self._find_and_pop(slot, key)
def insert(self, key, value):
h = aoc_hash(key)
slot = self.backing_store[h]
i = self._find_and_pop(slot, key)
if i != -1:
slot.insert(i, (key, value))
else:
slot.append((key, value))
It's worth noting that in general using a list/array for things such as
insertion and removal of elements right in the middle is not a good idea as the
operation generally needs to rearrange the whole list after inserting/removing
the element. Since we are doing a linear scan, a linked list
would be perfect for this. However, we are dealing with very small lists and
implementing a linked list in Python would be far slower in such case, so
let's stick to a normal list
.
Since the first two operations of .remove()
and .insert()
are the same, we
could also move them into ._find_and_pop()
, making it return both the slot and
the index:
class HashMap:
# ...
def _find_and_pop(self, key):
+ slot = self.backing_store[aoc_hash(key)]
+
for i, (k, _) in enumerate(slot):
if k == key:
slot.pop(i)
- return i
+ return slot, i
- return -1
+ return slot, -1
def remove(self, key):
- h = aoc_hash(key)
- slot = self.backing_store[h]
- self._find_and_pop(slot, key)
+ self._find_and_pop(key)
def insert(self, key, value):
- h = aoc_hash(key)
- slot = self.backing_store[h]
- i = self._find_and_pop(slot, key)
+ slot, i = self._find_and_pop(key)
if i != -1:
slot.insert(i, (key, value))
else:
slot.append((key, value))
That's about it. Now we can instantiate our HashMap
class and insert/remove keys as
requested by iterating over the input strings:
hashmap = HashMap()
for s in strings:
if b'-' in s:
key = s[:-1]
hashmap.remove(key)
elif b'=' in s:
key, value = s.split(b'=')
hashmap.insert(key, int(value))
The final value to calculatea "power" that corresponds to a weird sum over all entries present in the hashmap. For each entry, we need its value multiplied by its index in the slot (strating at 1) multiplied again by its slot number (starting at 1). Let's write a method to do this:
class HashMap:
# ... continued from above ...
def power(self):
total = 0
for slot_number, lst in self.backing_store.items():
for entry_index, (_, value) in enumerate(lst, 1):
total += (slot_number + 1) * entry_index * value
return total
The answer is now just one method call away:
total = hashmap.power()
print('Part 2:', total)
Problem statement — Complete solution — Back to top
We're on the roll with these ASCII grid problems it seems... well, we should know the drill by now, parsing the input is literally one line of code:
grid = open(...).read().splitlines()
As we already did for previous days, we can express coordinates as pairs (r, c)
of row index and column index, and directions as pairs of deltas (dr, dc)
. The
four directions we can go are up, down, left, right, and they correspond to the
deltas (-1, 0)
, (1, 0)
, (0, -1)
, (0, 1)
. Moving in any given direction
is a simple vector sum with the new coordinates being (r + dr, c + dc)
.
What about mirrors (\
and /
) though? How do they change the current
direction of travel? To extract a general rule we can simply test all
directions:
- Encountering
\
:- If we are going up
(-1, 0)
we end up going left(0, -1)
; - If we are going down
(1, 0)
we end up going right(0, 1)
; - If we are going left
(0, -1)
we end up going up(-1, 0)
; - If we are going right
(0, 1)
we end up going down(1, 0)
.
- If we are going up
- Encountering
/
:- If we are going up
(-1, 0)
we end up going right(0, 1)
; - If we are going down
(1, 0)
we end up going left(0, -1)
; - If we are going left
(0, -1)
we end up going down(1, 0)
; - If we are going right
(0, 1)
we end up going up(-1, 0)
.
- If we are going up
It looks like the \
mirror simply swaps the two deltas of the current
direction from (dr, dc)
to (dc, dr)
. The /
mirror does a similar
transformation, but the signs of the deltas are inverted: it transforms
(dr, dc)
in (-dc, -dr)
.
What about the splitters -
and |
? Well, these complicate things a little
bit, since they create additional paths to follow at the same time (at least
theoretically). Let's "ignore" them for now and write a function to explore the
grid following a single path, starting from the top left and going right. We
will consider any -
that is perpendicular to our direction as "go right" and
any |
perpendicular to our direction as "go down".
To count the number of cells we visit we can use a set of coordinates: we'll add the current coordinates to the set each step, and check its length when we are done.
How do we stop? well, the beam we are emulating can either escape the grid or end up in an infinite loop. The first case is easy to detect with a bounds check, and for the second one we can keep a set of already visited states: if we ever find ourselves on the same cell and going in a direction that we already traveled on that cell, we will know we entered a loop.
This is enough to implement an initial travel()
function:
def travel(grid):
height = len(grid)
width = len(grid[0])
r, c = 0, 0 # top left corner
dr, dc = 0, 1 # going right
seen_cells = set()
seen = set()
while 1:
# Are we out of bounds?
if not (0 <= r < height and 0 <= c < width):
# Can't possibly continue!
break
# Dif we already get here while also going in the same direction?
if (r, c, dr, dc) in seen:
# This is a loop!
break
seen.add((r, c, dr, dc))
seen_cells.add((r, c))
cell = grid[r][c]
if cell == '/':
dr, dc = -dc, -dr
elif cell == '\\':
dr, dc = dc, dr
elif cell == '-':
# If we are moving either up or down, just turn right
if dr != 0:
dr, dc = 0, 1
elif cell == '|':
# If we are moving either left or right, just turn down
if dc != 0:
dr, dc = 1, 0
r += dr
c += dc
return len(seen_cells)
The two checks right at the start of the while
loop can be moved to the loop
condition, which becomes:
while 0 <= r < height and 0 <= c < width and (r, c, dr, dc) not in seen:
# ...
Okay, but what about following multiple paths? It's not that complicated: we can
remember the paths that we still have to visit in a queue. Then, each time we
encounter a -
or |
split that is perpendicular to our direction of travel,
we keep advancing in one of the two directions and put the second (unexplored)
direction in the queue for later. Once we are done exploring, we can pop from
the queue and continue from there. This approach corresponds to a
depth-first exploration of all the paths.
Let's use a deque
for our queue. We only need to add
one more while
loop:
def travel(grid):
queue = deque([(0, 0, 0, 1)]) # r, c, dr, dc
height = len(grid)
width = len(grid[0])
seen_cells = set()
seen = set()
# While there are still paths to explore (i.e. queue is not empty)
while queue:
r, c, dr, dc = queue.pop()
while 0 <= r < height and 0 <= c < width and (r, c, dr, dc) not in seen:
seen.add((r, c, dr, dc))
seen_cells.add((r, c))
cell = grid[r][c]
if cell == '/':
dr, dc = -dc, -dr
elif cell == '\\':
dr, dc = dc, dr
elif cell == '-':
# If we are we moving either up or down, just go right
if dr != 0:
dr, dc = 0, 1
# We will explore the path that goes left later
queue.append((r, c - 1, 0, -1))
elif cell == '|':
# If we are we moving either left or right, just go down
if dc != 0:
dr, dc = 1, 0
# We will explore the path that goes up later
queue.append((r - 1, c, -1, 0))
r += dr
c += dc
return len(seen_cells)
Awesome, we can now call our function and solve the problem:
n_energized = travel(grid)
print('Part 1:', n_energized)
We now need to try and start exploring from any possible cell on the grid's perimeter and calculate the maximum possible number of encountered cells. We always start going towards the inside, so if we start on the left wall the initial direction is right, if we start on the bottom wall it's up, and so on. For corner cells, we need to test both possible start directions (e.g., for the top left cell we can either start going right or down).
Well... we already have the code, let's just move the start coordinates and
direction to the parameters of our travel()
function:
-def travel(grid):
- queue = deque([(0, 0, 0, 1)])
+def travel(grid, startr, startc, dr, dc):
+ queue = deque([(startr, startc, dr, dc)])
# ...
For simplicity, let's also pass the width and height of the grid to the
function, so that we don't have to call len()
a thousand times:
-def travel(grid, startr, startc, dr, dc):
+def travel(grid, height, width, startr, startc, dr, dc):
queue = deque([(startr, startc, dr, dc)])
- height = len(grid)
- width = len(grid[0])
# ...
The function call for part 1 can now be adjusted as follows:
height, width = len(grid), len(grid[0])
n_energized = travel(grid, height, width, 0, 0, 0, 1)
print('Part 1:', n_energized)
For part 2, we can call the function for every cell on the perimeter of the grid and calculate the maximum value returned:
best = 0
# Vertical walls
for r in range(len(grid)):
# Leftmost, start going right
best = max(best, travel(grid, height, width, r, 0, 0, 1))
# Rightmost, start going left
best = max(best, travel(grid, height, width, r, width - 1, 0, -1))
# Horizontal walls
for c in range(len(grid[0])):
# Top, start going down
best = max(best, travel(grid, height, width, 0, c, 1, 0))
# Bottom, start going left
best = max(best, travel(grid, height, width, height - 1, c, -1, 0))
print('Part 2:', best)
32 stars! I like powers of two.
Problem statement — Complete solution — Back to top
The long awaited shortest path finding problem of the year! But of course, this is not just plain boring shortest path finding, there's a twist to make things interesting.
Let's get input parsing out of the way. Once again we have a grid of ccharacters: the code to parse it is basically just muscle memory at this point...
grid = open(...).read().splitlines()
Since incidentally these characters are also digits, and we'll need them in
integer form, let's also convert everything to int
with the help of
map()
plus a generator expression, and while
we are at it, let's also compute the height and width of the grid, which will be
useful later as usual.
grid = [list(map(int, row)) for row in grid]
height = len(grid)
width = len(grid[0])
The raw information given in the form of a grid can be used to build a weighted directed graph. The only three things we need to understand are:
- How to represent a node.
- How to identify the neighbors of a node (i.e. how to identify the arcs that exit a given node).
- How to assign weights to such arcs.
When it comes to cells in a grid, the instinctive way to represent nodes is as
simple pairs of coordinates (r, c)
. However, this time it won't be enough.
Given the movement rules in the problem statement, our movement on a given cell
may be restricted to a cetain set of ditections: for example, if we reached a
cell after 3 consecutive right, the only directions we can go from there are up
or down.
Ideally, the only thing that should determine the possible paths to follow from a given node is the node itself, and not how the node was reached. If we reach the same coordinates twice without knowing where we came from, we cannot possibly know whether we should stop because we already explored all possible paths from the node, or continue because this time new paths are reachable.
Given that we can reach nodes in differnt ways, and that this will directly
influence which nodes will be reachable from them, we need more information than
just the cell coordinates to represent a node. The bare minimum needed in this
case is the coordinates plus the directions we are allowed to explore from the
node. We can represent a node using a tuple of the form
(coords, possible_directions)
(where coords
is a pair (r, c)
).
How can we identify the neighbors of a node then? The problem statement tells us that, from any given cell, we can travel between 1 and 3 consecutive steps in the same direction until we need to turn. Therefore, given a node, one way to find its neighbors is to identify all the possible coordinates that are reachable within 3 steps in a straight line in any of the allowed directions.
Take this small grid as an example:
012345
------
0|X23522
1|411111
2|211111
We start on the cell marked as X
with coordinates (0, 0)
, and we can travel
right or down. The coordinates reachable from this cell therefore are:
- Traveling right:
(0, 1)
with distance 2,(0, 2)
with distance 5 (2 + 3), and(0, 3)
with distance 10 (2 + 3 + 5). - Traveling down:
(1, 0)
with distance 4, and(2, 0)
with distance 6 (4 + 2).
What about the coordinates reachable from (0, 1)
then? Technically, we can
travel right from (0, 1)
to (0, 2)
with distance 2, for a total distance of
4 from the start, but given how we explored the grid, we already know that we
can reach (0, 2)
with distance 4 from (0, 0)
, so that would be a pointless
intermediate step to reach the same goal.
Similarly, we could go from (1, 0)
down to (2, 0)
with distance 2, for a
total distance of 4 from the start, but we already know that (2, 0)
is
reachable with distance 4 from the start!
The above reasoning gives a pretty neat way to identify neighbors: given a node identified by a pair of coordintates and possible directions of travel, travel on a straight line in any of the possible directions up to 3 steps. The neighbors of the node will then be any of the reached coordinates, and each of them will be only allowed to turn 90 degrees to continue traveling.
As another example:
012345
------
0|111111
1|3X2751
2|111111
The neighbors of the node ((1, 2), [left, right])
(the X
cell above) are:
((1, 0), [up, down])
at distance 3, with neighbors((0, 0), [left, right])
and((2, 0), [left, right])
;((1, 3), [up, down])
at distance 2, with neighbors((0, 3), [left, right])
and((2, 3), [left, right])
;((1, 4), [up, down])
at distance 9, with neighbors((0, 4), [left, right])
and((2, 4), [left, right])
;((1, 5), [up, down])
at distance 14, with neighbors((0, 5), [left, right])
and((2, 5), [left, right])
.
As you may have noticed, using this kind of representation, we only ever switch
from allowed directions [left, right]
to allowed directions [up, down]
. We
can therefore simplify things and say that the allowed directions are either
vertical or horizontal, which can be represented with a single bool
variable
tells us whether we can travel vertically or not vertically (i.e. horizontally).
Finally, arcs: an arc will simply be a tuple of the form
(neighbor_node, weight)
, and the weight of the arc, as is probably already
obvious from the above examples, is simply the sum of the numbers encountered on
the grid (including the destination cell, but not the starting cell).
Let's write a generator function that, given a pair of starting
coordinates and a direction, calculates and yields all the coordinates and the
weights moving in a straight line up to 3 steps. It seems simple enough: just
keep going for at most 3 steps in the given direction as long as you are inside
the grid, add up all the weights encountered, and yield
a tuple
(coords, weight)
each time.
def straight_line(grid, height, width, r, c, deltar, deltar):
weight = 0
for _ in range(1, 3 + 1):
r += deltar
c += deltac
if not (0 <= r < height and 0 <= c < width):
break
weight += grid[r][c]
yield ((r, c), weight)
With the help of the above code, we can now easily write another generator
function that, given a node, will yield
all its neighbors. As we already said,
a node is a tuple of the form ((r, c), can_travel_vertically)
. We have two
main cases: if we can go vertically, travel in a straight line up and down up to
3 steps, and if we can not go vertically then travel in a straight line left and
right up to 3 steps. This transltes to 4 calls to the above function.
def neighbors(grid, height, width, node):
(r, c), can_travel_vertically = node
if can_travel_vertically:
# Up
for coords, weight in straight_line(grid, height, width, r, c, -1, 0):
yield ((coords, False), weight)
# Down
for coords, weight in straight_line(grid, height, width, r, c, 1, 0):
yield ((coords, False), weight)
else:
# Left
for coords, weight in straight_line(grid, height, width, r, c, 1, -1):
yield ((coords, True), weight)
# Right
for coords, weight in straight_line(grid, height, width, r, c, 1, 1):
yield ((coords, True), weight)
The above function is definitely still siplifiable, but to avoid wasting ages I'll leave that as an exercise to the reader (or you can just check out my complete solution linked above).
Awesome! We now have:
- The definition of a node.
- The definition of an arc and its weight.
- A way to list the outgoing arcs from a node (i.e. to determine all the neighbors of a given node).
This is all we need to implement a pathfinding algorithm. In particular, since we are interested in finding the shrotest path from one node to another, we can use the famous Dijkstra's algorithm.
Copying almost verbatim from my 2019 walkthrough: as we already did in previous
years (2019 d6 p2, 2021 d15 p1), w ewill implement
Dijkstra's algorithm using a min-heap as a priority
queue to hold the nodes to visit and always pop the one
with the shortest distance from the source. The heapq
module is
exactly what we need. A defaultdict
that returns
float('inf')
(also provided by math.inf
) as the default value is also useful
to treat not-yet-seen nodes as being infinitely distant (positive floating point
infinity compares greater than any integer).
The algorithm is well-known and also well-explained in the Wikipedia page I just linked above, so I'm not going into much detail about it, I'll just add some comments to the code.
import heapq
from collections import defaultdict
from math import inf as INFINITY
def dijkstra(grid, height, width, src_coords, dst_coords):
# Set of visited nodes to avoid visiting the same node twice
visited = set()
# List of (distance, node) used as heap to choose the next node to visit
queue = []
# We'll start from the coordinates src_coords. This is the only special case
# where we can go bot horizontally and vertically, so we'll add two nodes to
# the initial queue.
queue.append((0, (src_coords, False)))
queue.append((0, (src_coords, True)))
# Dictionary of the form {node: total_distance} keeping track of the minimum
# distance from any given node to the start
distance = defaultdict(lambda: INFINITY)
# While we have nodes to visit
while queue:
# Pop the node with lowest distance from the priority queue
dist, node = heapq.heappop(queue)
# If we got to the destination, we found what we were looking for
if node[0] == dst_coords:
return dist
# If we already visited this node, skip it, proceed to the next one
if node in visited:
continue
# Mark the node as visited
visited.add(node)
# For each neighbor of this node...
for neighbor, weight in neighbors(grid, height, width, node):
# Calculate the total distance from the source to this neighbor
# passing through this node
new_dist = dist + weight
# If the new distance is lower than the minimum distance we know to
# reach this neighbor, then update its minimum distance and add it
# to the queue, as we found a "better" path to it
if new_dist < distance[neighbor]:
distance[neighbor] = new_dist
heapq.heappush(queue, (new_dist, neighbor))
# If we ever empty the queue without entering the node[0] == dst_coords
# check in the above loop, there is no path from source to destination
return INFINITY
Assuming we wrote everything correctly, all that's left to do is call the function to get our answer:
src_coords = (0, 0)
dst_coords = (height - 1, width - 1)
min_dist = dijkstra(grid, height, width, src_coords, dst_coords)
print('Part 1:', min_dist)
For part 2, we still need to find the shortest path to victory, but this time the movement rules change: whenever we move, we always have to take at least 4 steps in the same direction. Then, we can keep going for up to 6 additional steps in the same direction before we necessarily have to turn.
Well, thankfully, given the way we coded things for part 1, the changes to make
are not that many. The only function to change is straight_line()
, which
should now take into account the minimum distance of 4 and maximum of 4+6 = 10.
To make it generic enough, let's take these numbers as parameters:
-def straight_line(grid, height, width, r, c, deltar, deltar):
+def straight_line(grid, height, width, r, c, deltar, deltar, start, stop):
weight = 0
- for _ in range(1, 3 + 1):
+ for i in range(1, stop + 1):
r += deltar
c += deltac
if not (0 <= r < height and 0 <= c < width):
break
weight += grid[r][c]
- yield ((r, c), weight)
+ if i >= start:
+ yield ((r, c), weight)
The only other changes to make are simplying a matter of passing around the
start
and stop
values:
-def neighbors(grid, height, width, node):
+def neighbors(grid, height, width, node, start, stop):
(r, c), can_travel_vertically = node
if can_travel_vertically:
- for coords, weight in straight_line(grid, height, width, r, c, -1, 0):
+ for coords, weight in straight_line(grid, height, width, r, c, -1, 0, start, stop):
yield ((coords, False), weight)
- for coords, weight in straight_line(grid, height, width, r, c, 1, 0):
+ for coords, weight in straight_line(grid, height, width, r, c, 1, 0, start, stop):
yield ((coords, False), weight)
else:
- for coords, weight in straight_line(grid, height, width, r, c, 1, -1):
+ for coords, weight in straight_line(grid, height, width, r, c, 1, -1, start, stop):
yield ((coords, True), weight)
- for coords, weight in straight_line(grid, height, width, r, c, 1, 1):
+ for coords, weight in straight_line(grid, height, width, r, c, 1, 1, start, stop):
yield ((coords, True), weight)
And in the main dijkstra()
function:
-def dijkstra(grid, height, width, src_coords, dst_coords)
+def dijkstra(grid, height, width, src_coords, dst_coords, start, stop):
# ....
- for neighbor, weight in neighbors(grid, height, width, node):
+ for neighbor, weight in neighbors(grid, height, width, node, start, stop):
# ...
We can now calculate the answer for both parts with two function calls:
min_dist1 = dijkstra(grid, height, width, src_coords, dst_coords, 1, 3)
print('Part 1:', min_dist1)
min_dist2 = dijkstra(grid, height, width, src_coords, dst_coords, 4, 10)
print('Part 2:', min_dist2)
Problem statement — Complete solution — Back to top
No way, grids again? Ok, no, not really. This time we are not dealing directly with a gid, but rather only following directions to draw something on it.
Given that the input instructions to draw edges only contain directions to go up/down/left/right, what we will end up drawing a polygon that only has right angles and vertical or horizontal edges. Our task is to calculate its total area expressed in terms of points with integral coordinates that reside either inside or on the edges of the polygon.
How can this be done? There are multiple ways: the naïve approach would be to draw the entire perimeter of the polygon following the directions liteally and then determine use breadth-first search starting from any cell on its inside to count how many cells there are (including the perimeter). That would work, but would definitely not be optimal.
In fact, closed formulas exist that can solve this problem. A quick internet search for "area of a polygon given its vertices" brings up the shoelace formula, also known as "Gauss' area formula". The formula is more general and also works for edges that are not horizontal and vertical, so in our case it can also be simplified a bit. There are different variations of the formula, but the one that I think fits the scenario best is the trapezoid formula.
The core concept is easy to undestand with the help of an example. Consider the following polygon:
^ Y
| 1 8
| ########
| # # 6
| # ###########
| # 7| #
| # | 3 #
| ############# #
| 2| | # #
| | | ######
| | | 4| |5
| | | | |
O-----x------x----x----x-----> X
Ignoring the vertical segments (they have zero area below them), if we denote
the area below a segment from a to b with
The Shoelace formula in fact calculates the area exactly like this: as the
summation of the signed areas below the edges of the polygon considered in
counter-clockwise order. The area
If we iterate over the vertices of the above example pairwise, we get:
The only additional thing to consider is that we also want to include the cells
on the perimeter of the polygon itself in our area, but the formula we just
wrote will only include part of it. To overcome this, we can apply
Pick's theorem, which states that
Jonathan Paulson made a good and straight-to-the-point video explanation of this approach in his youtube video here.
Okay, we have our mind set on the goal, now onto the coding part. We want the coordinates of all the vertices and the number of points on the perimeter. We can calculate all of this in one go by parsing the input line by line.
We will .split()
each line extracting the direction and the
number of steps that follows (ignoring the last hexadecimal part for now as the
problem statement says). We will express coordinates as (r, c)
as usual (using
x
and y
would be equivalent, but I find r
and c
more intuitive).
Starting from an arbitrary point, for example (0, 0)
, we will then move each
time in the given direction for the given number of steps, and add the new
coordinates to a list of vertices. The direction, which can be one of UDLR
,
can be easily converted into deltas to apply to r
and c
with the help of a
dictionary.
fin = open(...)
dirmap = {'R': (0, 1), 'D': (1, 0), 'L': (0, -1), 'U': (-1, 0)}
vertices = []
perimeter = 0
r = c = 0
for line in fin:
direction, steps, _ = line.split()
steps = int(steps)
dr, dc = dirmap[direction]
r += steps * dr
c += steps * dc
vertices.append((r, c))
perimeter += steps
Since we will iterate over vertices pairwise, and we also want include in our
calculations the edge that goes from the last vertex to the first, we will add
(0, 0)
at the end of the list.
vertices.append((0, 0))
We now have a list of vertices and the perimeter of the polygon. Let's write a
function to calculate the inner area using the shoelace formula. To iterate over
pairs of vertices we can use itertools.pairwise()
, or
alternatively re-implement it ourselves using
itertools.tee()
as the example in the Pytho ndocumentation
shows, since pairwise()
is only available from Python 3.10.
Since we want to iterate over vertices in counter-clockwise order, but we don't
know if they are ordered clockwise or counter-clockwise, there is a chance that
the resulting area will be negative. That's no issue, we can calculate its
absolute value with abs()
.
from itertools import pairwise
def shoelace(vertices):
area = 0
for (r1, c1), (_, c2) in pairwise(vertices):
area += (c2 - c1) * r1
return abs(area)
We can now calculate the inner area and then apply Pick's theorem to get the
total area we want. We want an integer result, and the perimeter we calculate
will always be an even number (given that the path we take from the initial
vertex always brings us back to the origin), so we can use itneger division
(//
).
area = shoelace(vertices) + perimeter // 2 + 1
print('Part 1:', area)
For part 2 we are told that the hexadecimal values we ignored in part 1 actually
represent the real edges: the lowest hex digit represents the direction, which
can be 0123
for RDLU
, and the other digits represent the number of steps to
take (the length of the edge).
We already have all we need to solve the problem, we only need to feed our
function new values. We can either rewind the input and parse it again or do
everything in a single loop and keep two lists of vertices and two perimeters.
To convert hexadecimal numbers into decimal we can pass 16
(the base to
convert from) as second argument to int()
. In any case, we
can add 0123
to our dirmap
to convert those digits to the appropriate
deltas.
dirmap = {
'R': (0, 1), 'D': (1, 0), 'L': (0, -1), 'U': (-1, 0),
'0': (0, 1), '1': (1, 0), '2': (0, -1), '3': (-1, 0),
}
vertices1 = []
vertices2 = []
perimeter1 = 0
perimeter2 = 0
r1 = c1 = r2 = c2 = 0
for line in fin:
direction, steps1, hexval = line.split()
steps1 = int(steps1)
dr, dc = dirmap[direction]
r1 += steps1 * dr
c1 += steps1 * dc
direction = hexval[-2]
steps2 = int(hexval[2:-2], 16)
dr, dc = dirmap[direction]
r2 += steps2 * dr
c2 += steps2 * dc
vertices1.append((r1, c1))
vertices2.append((r2, c2))
perimeter1 += steps1
perimeter2 += steps2
vertices1.append((0, 0))
vertices2.append((0, 0))
Yeah, not the prettiest code I've seen... but why complicate things? The solutions for both parts are both one function call away:
area1 = shoelace(vertices1) + perimeter1 // 2 + 1
print('Part 1:', area1)
area2 = shoelace(vertices2) + perimeter2 // 2 + 1
print('Part 2:', area2)
Problem statement — Complete solution — Back to top
Interesting problem today, I'd say the most interesting so far this year. We need to emulate a set of "workflows" composed of rules that are linked to each other.
To the input parsing! We have two section of input separated by an empty line,
so we can just read everything, .split()
on \n\n
to get the
two sections, and then use .splitlines()
on each section.
with open(...) as fin:
raw_workflows, raw_variables = fin.read().split('\n\n')
raw_workflows = raw_workflows.splitlines()
raw_variables = raw_variables.splitlines()
Since parsing is a bit tedious, let's write two functions: one to parse the workflows and another one to parse the variable assignments.
Each line in raw_workflows
represents a single workflow of the following
form:
zzz{a>3102:vrv,a<2800:cqj,a<2999:A,bvc}
The name (zzz
) can be easily extracted by locating the first {
character
through .find()
and then slicing. The rules start right after
{
and end at the second-to-last character, so they can again be extracted by
slicing. They can then be .split()
on commas (,
) and parsed.
Each rule has the form expression:next_workflow_name
, except the last rule
which only represent the workflow name to go to if none of the previous rules
are satisfied. Each rule can therefore be .split()
on the colon (:
) through
a simple generator expression to obtain a list of pairs
[expression, next_workflow_name]
.
Each expression has the form v>num
or v<num
where v
is a
variable name and num
is an integer. We can extract these three components
with slicing, and once again use a generator expression to transform each
[expression, next_workflow_name]
pair into
(var_name, op, value, next_workflow_name)
.
Where do we store parsed workflows? Well, since each workflow include rules that
identify other workflows by name, a dictionary of the form
{name: (rules, last_name)}
seems ideal ot easy find workflows by name. That
last_name
is simply the name in the final rule, which does not have an
expression.
Here's a function that parses workflows from raw input lines to a dictionary as described above:
def parse_workflows(lines):
workflows = {}
for line in lines:
name = line[:line.find('{')]
rules = line[line.find('{') + 1:-1].split(',')
# Separate the last rule (name only) from the rest
rules, last = rules[:-1], rules[-1]
# ['expr1:name1', ...] -> [[expr1, name1], ...]
rules = (r.split(':') for r in rules)
# [[expr1, name1], ...] -> [(var_name, op, value, next_workflow_name), ...]
rules = [(exp[0], exp[1], int(exp[2:]), nxt) for exp, nxt in rules]
workflows[name] = (rules, last)
return workflows
Parsing variable assignments is simpler: remove the leading {
and trailing
}
, split each assignment on =
to have pairs of the form [var_name, value]
,
and finally convert each value to integer.
Again, the most convenient container to use is a dictionary of the form
{var_name: value}
. This is because workflow rules will refer to the variable
by name, so we need an easy way to get their value given their name. This
dictionary can be consicely built using
dict comprehension.
Since these assignments only need to be iterated over once, we can write a generator function that yields one dictionary at a time.
def parse_variables(lines):
for line in lines:
# '{a=123,b=456}' -> ['a=123', 'b=456']
assignments = line[1:-1].split(',')
# ['a=123', 'b=456'] -> [('a', '123'), ('b', '456')]
assignments = map(lambda a: a.split('='), assignments)
# [('a', '123'), ('b', '456')] -> {'a': 123, 'b': 456}
yield {a[0]: int(a[1]) for a in assignments}
Now we can parse the two input sections we have using the functions we just wrote:
workflows = parse_workflows(raw_workflows)
variables = parse_variables(raw_variables)
Our task is to try running the workflows (starting from the one named in
) with
all the given variable assignments, and for workflows that succeed (i.e. end up
referencing the dummy A
workflow) sum the variable values.
Let's write a function to run the workflows we have given one dictionary of variable assignments. The algorithm seems straightforward:
- Start with the workflow
in
; - For each rule in the workflow, test the referenced variable against the
given value, depending on the given operation (
<
or>
); - If the check passes, proceed to the next workflow identified by the rule, otherwise proceed to the next rule;
- If all the rules of the current workflow are processed without any match
(i.e. none of the tested expressions were saisfied), proceed to the
last
workflow (corresponding to the final rule that has no expression); - Continue until a workflow named
'A'
or'R'
is encountered; - In case
'A'
is encountered, return the sum of the variable values, or zero otherwise.
Here's the implementation:
def run(workflows, variables):
cur = 'in'
while cur != 'A' and cur != 'R':
rules, last = workflows[cur]
for varname, op, value, nxt in rules:
var = variables[varname]
if op == '<' and var < value:
cur = nxt
break
if op == '>' and var > value:
cur = nxt
break
else:
# No break statement encountered (i.e., none of the rules matched)
cur = last
if cur == 'A':
return sum(variables.values())
return 0
Using the above function, we can now calculate the value for all the variable
assignments we have, and sum()
everything up.
total = sum(run(workflows, v) for v in variables)
We can also use map()
instead of a generator expression if
we first bind the first argument of run
to be always workflows
using
functoos.partial()
:
total = sum(map(partial(run, workflows), variables))
print('Part 1:', total)
Now we are told to ignore the given variable assignments. We know that each of
the four variables we have (x
, m
, a
, s
) can have a value ranging from 1
to 4000 (ends included), and want to calculate how many among all the possible
value assignments are accepted by the workflows.
We clearly cannot test all the possible assignments in a reasonable time, since we have 4 ranges of 4000 values, so a total of 40004 = 256000000000000 (256 trillion) unique possible assignments. We need to architect a smarter strategy.
It may already be obvious, but the workflows we are given represent a directed graph: each workflow represents a node, and is connected to other workflows through some ordered expressions (arcs). We can advance from one node to the other testing the expressions in the given order and taking the arc corresponding to the first satisfied expression (or the final arc with no expression).
Additionally, there are no loops (otherwise part 1 would have been impossible),
so the graph is also a tree, where the root is the workflow named
'in'
, and the only two leaves are the two workflows named 'A'
and 'R'
.
We can solve the problem in a single exploration of the tree from the root down to the leaves, either using BFS or DFS. The latter is simpler to implement as a recursive function. To do this, along the way we'll keep track of ranges of possible values instead of single values.
Starting from the root workflow ('in'
) with [1, 4000] as the possible range
for all the variables, the strategy to implement is as follows:
- If we reach the
'A'
workflow, return the product of all the range sizes, which corresponds to the number of possible combinations; - Otherwise, if we reach the
'R'
workflow, return0
; - Otherwise, we need to go through the rules of the current workflow. For each
rule:
- Check the value v to test and the current range [lo, hi] of possible values for the tested variable. In case the expression is var < v, we need to discard all the values above v, so the new accepted range will be [lo, v - 1]. In case the operation is var > v, we need to discard all the values below v, so the new accepted range will be [v + 1, hi].
- Make a recursive call to know the amount of accepted values for the new ranges. Add the result of the recursive call to the total number of accepted rules.
- Update the current range for the tested variable with the opposite of the accepted range (we advance to the next rule only if the current did not match). In case of var < v we advance with [v, hi], while in case od var > v we advance with [lo, value].
- Finally, after processing all the rules, make a last recursive call to also explore the workflow corresponding to the final rule that has no associated expression using the updated ranges. Add the result of the recursive call to the total and return it.
The above algorithm essentially splits the search space in half each time a rule is evaluated: one half is accepted and continues to the next workflow, while the other half is rejected and passed on to the next rule of the current workflow.
This can be nicely visualized using a Sankey diagram, which is what was done here in this Reddit post in the AoC subreddit. Pretty neat!
We'll implement the above as a recursive function taking 3 arguments: the
workflows, the initial ranges and the current workflow name. The variable ranges
will be represented with a dictionary of the form {var_name: (lo, hi)}
,
starting with lo=1
and hi=4000
for all variables.
To update the dictionary of ranges and pass it on recursively, we can use the
bitwise OR operator (|
), which has the effect of creating a new dictionary
with the values taken from the right side of the operator overriding the
originals, as documented here. So, for example:
{'x': 1, 'y': 2} | {'x': 3}
will produce {'x': 3, 'y': 2}
.
Here's the implementation of the above:
def count_accepted(workflows, ranges, cur='in'):
if cur == 'A':
# The ranges we have are accepted, return the number of possible
# combinations, which corresponds to the product of the range sizes
product = 1
for lo, hi in ranges.values()
product *= hi - lo + 1
return product
if cur == 'R':
# The ranges we have were rejected
return 0
rules, last = workflows[cur]
total = 0
for var, op, value, nxt in rules:
lo, hi = ranges[var]
if op == '<':
# Check if this rule can match any of the values in [lo, hi]
if lo < value:
# Crate a new ranges dictionary updating the range of this
# variable from [lo, hi] to [lo, value - 1], then explore the
# next workflow with a recursive call
ranges | {var: (lo, value - 1)}
total += count_accepted(workflows, , nxt)
# If possible, update the current range with the opposite of the
# match (since we also want to explore the possibility of no match),
# so from [lo, hi] to [value, hi], and continue to the next rule
if hi >= value:
ranges[var] = (value, hi)
else:
# Pretty much the same reasoning as above...
if hi > value:
total += count_accepted(workflows, ranges | {var: (value + 1, hi)}, nxt)
if lo <= value:
ranges[var] = (lo, value)
# Also try processing the next workflow, which should be explored only if no
# rules match (we already updated all the ranges as needed)
total += count_accepted(workflows, ranges, last)
return total
The calculation of the product in case we encounter the 'A'
workflow can be
simplified using math.prod()
plus a
generator expression:
+from math import product
+
def count_accepted(workflows, ranges, cur='in'):
if cur == 'A':
# The ranges we have are accepted, return the number of possible
# combinations, which corresponds to the product of the range sizes
- product = 1
- for lo, hi in ranges.values()
- product *= hi - lo + 1
-
- return product
+ return prod(hi - lo + 1 for lo, hi in ranges.values())
# ... unchanged ...
We are once again one function call away from the solutoin, so let's get our 38th star:
accepted = count_accepted(workflows, {v: (1, 4000) for v in 'xmas'})
print('Part 2:', accepted)
Problem statement — Complete solution — Back to top
Today's problem is a fun one (well, at least for part 1). We are working with singals and want to propagate them through what is essentially a directed graph of "modules".
The input we are given consists of lines of the form src -> dst1, dst2, dst3
.
Te easiest way to represent the graph is a dictionary of the form
{node: list_of_neighbors}
, and the input provides us the key-value pairs to
fill it with.
We can represent "pulses" as booleans with True
meaning high and False
meaning low. In order to distinguish between different kind of modules, we can
use two additional dictionaries: one for "flip-flop" modules, where the values
will be the current state of each flip-flop, and one for "conjunction" modules,
where the values will be dictionaries used to remember the last state of the
inputs for each conjunction. Modules that are neighte flip-flops nor
conjunctions don't need to remember any state them.
Let's get to parsing then: separate sources from destinations by
.split()
-ing each line on the arrow ('->'
), then
.strip()
useless whitespace from both parts and .split()
again the destinations on commas (', '
) to get a list.
To determine the module type we simply check the first character: for flip-flops
we initialize their state to False
(low), while for conjunctions we'll create
a new empty dictionary (to fill later). Finally, we'll add everything to the
graph dictionary.
flops = {}
conjs = {}
graph = {}
with open(...) as fin:
for line in fin:
source, dests = line.split('->')
source = source.strip()
dests = dests.strip().split(', ')
if source[0] == '%':
source = source[1:]
flops[source] = False
elif source[0] == '&':
source = source[1:]
conjs[source] = {}
graph[source] = dests
Now that we have built the graph and recognized each flip-flop and conjunction,
we need to initialize the state of each conjunction module: for each
conjunction, we will add any module that has such conjunction in its destination
list to the conjunction's dictionary. This can be done with a simple for
loop
over the .items()
of the conjs
dictionary we just built.
for source, dests in graph.items():
for dest in dests:
if dest in conjs:
cnjs[dest][source] = False
Now the fun begins. Following the rules, we need to propagate an initial pulse
starting from the 'broadcaster'
module. This module gets a los (False
) pulse
from an initial button press, and such pulse will propagate according to the
rules in the problem statement.
We can implement a breadth-first exploration of the graph using a
deque
as FIFO queue. The elemenents to
enqueue will be tuples of the form (sender, receiver, pulse)
where sender
is
the module that sent the pulse, receiver
is the module currently receiving it,
and pulse
is True
/False
for high/low.
We'll keep going until there are elements in the queue, popping one at a time from the front and processing it, which may result in appending new elements to process to the tail of the queue. We'll also keep track of the number of low and high pulses encountered, as this is what we are asked.
Let's write a function for this: the implementation will simply follows the rules given by the problem statement.
from collections import deque
def run(graph, flops, conjs):
q = deque([('button', 'broadcaster', False)])
nhi = 0 # number of high pulses encountered
nlo = 0 # number of low pulses encountered
while q:
sender, receiver, pulse = q.popleft()
if pulse:
nhi += 1
else:
nlo += 1
if receiver in flops:
# Flip flops only react to low pulses
if pulse:
return
# When a low pulse is received, the state of the flip-flop switches
# to its opposite and a concording pulse is propagated
next_pulse = flops[receiver] = not flops[receiver]
elif receiver in conjs:
# When a pulse is received by a conjunction, the last known state of
# the sender is updated for this conjunction
conjs[receiver][sender] = pulse
# If all the last pulse seen from all the inputs of this conjunction
# was high, then a low pulse is propagated, otherwise a high pulse
next_pulse = not all(conjs[receiver].values())
elif receiver in graph:
# Neither a flip-flop nor a conjunction, propagate the pulse as is
next_pulse = pulse
else:
# This module is not connected to any other module, cannot propagate
return
# Now propagate the new pulse to all the modules connected to this one
for new_receiver in graph[receiver]:
q.append((receiver, new_receiver, next_pulse))
return nhi, nlo
To get our 39th star, we can now call the function 1000 times and calculate the product of the total number of low and high pulses observer:
tothi = totlo = 0
for _ in range(1000):
nhi, nlo = run(graph, flops, conjs)
tothi += nhi
totlo += nlo
answer = tothi * totlo
print('Part 1:', answer)
Things take an interesting turn: we need to reset all flip-flops and
conjunctions to their initial state, and then find out when will the 'rx'
module ever receive the first low pulse, in terms of number of button presses.
Theoretically, we already have the code to do this: in part 1 we had to emulate
1000 button presses, why not simply add a check inside the run()
function and
call it indefinitely until the check is hit? Well, of course, in classic Advent
of Code spirit, this brute force approach won't work. Or well, it would work, if
it weren't for the fact that the number we are looking for is insanely high, so
the solution would take a little bit too much time to run.
We need to take a look at our input to realize what's going on. A couple of visualizations posted today on AoC's subreddit are also helpful for this: one, two. In any case, a few interesting aspects stand out:
- There's only one
rx
module; - There's only one input to the
rx
module, and it's a conjunction module; - All the inputs to this conjunction module are also conjunction modules.
Let's stop and think for a second. Unfortunately we will have to make a few assumptions without which our life would be much, much harder. The good news is that these assumptions seem to hold, the bad news is that making assumptions theoretically makes the solution less general. As I take it though, this always happens for at least one or two AoC problems each year, so I'm relatively fine with it.
Consider the following:
- Assume that the three above conditions always hold (at least, they dit for my input and I assume everyone else's input);
- Let's call
$A$ the conjunction mentioned in point 2 above, which is the only input module torx
. Then, let's call$B_1, B_2, ..., B_N$ the N conjunctions mentioned in point 3 above, which are the only input modules to$A$ . - Given the behavior of conjunctions,
$A$ will send a low pulse toin
as soon as all the remembered pulses from its inputs ($B_i$ ) are high. - Each
$B_i$ will send a high pulse to$A$ every time it receives a low pulse (no matter the state, a low pulse received by a conjunction will always propagate another low pulse). - Assume that each
$B_i$ somehow periodically receives a low pulse every$P_i$ button presses (i.e., each$P_i$ runs of therun()
function we wrote). - Assume that the period of each
$B_i$ module starts at the first button press and ends at the first low pulse delivered to the module.
We now have N modules (
The N modules we have send high pulses in periods of different lengths starting at the same instant, and we want to know when they will synchronize. This will happen every $\text{lcm}(P_1,...,P_n)$ iterations. The situation is similar to the one of day 8 part 2.
The logic to use to propagate pulses is unchanged, so whatever code we write
will be very similar to the code for part 1. Let's therefore extract the pulse
propagation logic of the run()
function into an helper
generator function that we can then use for both part 1 and
part 2. This function will take the current sender
, receiver
and pulse
,
process the pulse according to the rules, and yield
all the next elements to
explore (i.e., to add the queue). It's as simple as cut and paste:
def propagate_pulse(graph, flops, conjs, sender, receiver, pulse):
if receiver in flops:
if pulse:
return
next_pulse = flops[receiver] = not flops[receiver]
elif receiver in conjs:
conjs[receiver][sender] = pulse
next_pulse = not all(conjs[receiver].values())
elif receiver in graph:
next_pulse = pulse
else:
return
for new_receiver in graph[receiver]:
yield receiver, new_receiver, next_pulse
The run()
function from part 1 can then be rewritten as follows:
def run(graph, flops, conjs):
q = deque([('button', 'broadcaster', False)])
nhi = nlo = 0
while q:
sender, receiver, pulse = q.popleft()
if pulse:
nhi += 1
else:
nlo += 1
q.extend(propagate_pulse(graph, flops, conjs, sender, receiver, pulse))
return nhi, nlo
We can now write a function to find out the rx
. Then, we
have to identify each .items()
of the
graph
we have. Since we are making assumptions, where possible, it's best to
make sure that they hold with assert
.
def find_periods(graph, flops, conjs):
periodic = set() # These are B1, B2, ... BN
for rx_source, dests in graph.items():
if dests == ['rx']:
# rx_source is A
assert rx_source in conjs
break
for source, dests in graph.items():
if rx_source in dests:
assert source in conj
periodic.add(source)
# TODO: find periods...
Cool. Now, to find the periods, we can simply count()
the iterations starting from 1
and each time do a full run. Whenever we
encounter a low pulse, if the receiver is one of the .discard()
it from our periodic
set. We will be done when
the set is empty.
from itertools import count
def find_periods(graph, flops, conjs):
# ... same as above ...
for iteration in count(1):
q = deque([('button', 'broadcaster', False)])
while q:
sender, receiver, pulse = q.popleft()
if not pulse:
if receiver in periodic:
yield iteration
periodic.discard(receiver)
if not periodic:
return
q.extend(propagate_pulse(graph, flops, conjs, sender, receiver, pulse))
Now we can simply call the above function and calculate the LCM of
all the numbers it returns using math.lcm()
. First thouh, we
have to reset the state of all the flip-flops and the conjunctions. That's not
a problem it only takes a couple of for
loops.
from math import lcm
# Reset flip-flops
for f in flops:
flops[f] = False
# Reset conjunctions
for inputs in conjs.values():
for i in inputs:
inputs[i] = False
answer = lcm(*find_periods(graph, flops, conjs))
print('Part 2:', answer)
40 stars collected, 10 more to go!
Copyright © 2023 Marco Bonelli. This document is licensed under the Creative Commons BY-NC-SA 4.0 license.