-
Notifications
You must be signed in to change notification settings - Fork 0
/
12.py
75 lines (63 loc) · 2.12 KB
/
12.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import sys
from functools import lru_cache
# Raised defensively for the recursive solver below.
sys.setrecursionlimit(50000)
# Clear the screen before printing. "ESC [ 2 J" is the Erase-in-Display
# control sequence and needs an uppercase 'J'; a lowercase 'j' is a different
# control function. "ESC c" then fully resets the terminal.
print(chr(27) + '[2J')
print('\033c')
# Puzzle input path; point this at '12.test' to run the worked example instead.
INPUT_PATH = '12.input'
# Read every record with surrounding whitespace (including the newline)
# stripped; the context manager closes the handle once the list is built.
with open(INPUT_PATH, 'r') as f:
    lines = [x.strip() for x in f]
# ANSI colour escape codes (not referenced by the code visible in this file).
RED = '\033[91m'
GREEN = '\033[92m'
RESET = '\033[0m'
@lru_cache(maxsize=None)
def solve(line: str) -> int:
    """Count the ways a condition string can match a list of run lengths.

    ``line`` has the form ``"<condition> <groups>"``. ``condition`` is a
    string over ``'.'``, ``'#'`` and ``'?'``. ``groups`` is a comma-separated
    list of integers, possibly empty. Every ``'#'`` must belong to a run,
    ``'.'`` separates runs, and ``'?'`` may be either. The runs of ``'#'`` must
    have exactly the lengths in ``groups``, in order.

    The whole string is the memoization key, so recursive calls re-encode
    their sub-problem in the same ``"<condition> <groups>"`` format.

    Returns the number of valid assignments for the ``'?'`` characters.
    Raises ValueError if ``line`` does not contain exactly one space.
    """
    condition, springs_str = line.split(' ')
    springs_str = springs_str.strip()
    if not springs_str:
        # No runs left to place: valid only if nothing still requires one.
        return 0 if '#' in condition else 1
    if not condition:
        # Runs remain, but there is no room left for them.
        return 0
    # A trailing '.' guarantees a separator character after any run that
    # fits, so condition[length] below is always in range.
    if condition[-1] != '.':
        condition += '.'
    springs = list(map(int, springs_str.split(',')))
    first = condition[0]
    result = 0
    if first in '#?':
        # Option 1: start the first run here. It must fit, contain no '.',
        # and be followed by a non-'#' character that acts as its separator.
        length = springs[0]
        fits = (
            length < len(condition)
            and '.' not in condition[:length]
            and condition[length] != '#'
        )
        if fits:
            rest = ",".join(map(str, springs[1:]))
            # Skip the run plus its separator, and drop the placed group.
            result += solve(condition[length + 1:] + " " + rest)
    if first in '.?':
        # Option 2: treat this character as a separator and move on.
        result += solve(condition[1:] + " " + ",".join(map(str, springs)))
    return result
def unfold(line: str):
conditions, springs = line.split()
return"?".join([conditions]*5) + " " + ",".join([springs]*5)
def solve2(line: str):
    """Count the arrangements for the five-fold unfolded form of ``line``."""
    expanded = unfold(line)
    return solve(expanded)
# assert(solve('?#?#?#?#?#?#?#? 1,3,1,6') == 1)
# assert(solve('?###???????? 3,2,1') == 10)
# assert(solve('???.### 1,1,3') == 1)
# assert(solve('') == 4)
# assert(solve('????.#...#... 4,1,1') == 1)
# assert(solve('????.######..#####. 0,6,5') == 4)
# assert(solve2('?#?#?#?#?#?#?#? 1,3,1,6') == 1)
# assert(solve2('???.### 1,1,3') == 1)
# assert(solve2('????.#...#... 4,1,1') == 16)
# assert(solve2('????.######..#####. 1,6,5') == 2500)
# assert(solve2('.??..??...?##. 1,1,3') == 16384)
# assert(solve2('?###???????? 3,2,1') == 506250)
# Sum the unfolded (part-two) arrangement counts over every input record.
total = sum(solve2(line) for line in lines)
print("Total:", total)