#include "aoc.h"

#include <stdint.h>
#include <stdio.h>
#include <string.h>

typedef struct Vec2i {
    i32 x, y;
} Vec2i;

// Selects which keypad layout a key code or coordinate refers to.
typedef enum {
    TABLE_NUMPAD = 0,
    TABLE_DIRECTIONS = 1,
} Table;

// Keys of the directional keypad; values index TABLES[TABLE_DIRECTIONS] and DIRECTION_PATHS.
typedef enum {
    D_LEFT,
    D_UP,
    D_RIGHT,
    D_DOWN,
    D_A,
    NUM_DIRECTION_KEYS,
} DirectionKey;

// Keys of the numeric keypad; values index TABLES[TABLE_NUMPAD] and NUMPAD_PATHS.
typedef enum {
    N_0,
    N_1,
    N_2,
    N_3,
    N_4,
    N_5,
    N_6,
    N_7,
    N_8,
    N_9,
    N_A,
    NUM_NUMPAD_KEYS,
} NumpadKey;

// Key coordinates for each keypad. x grows to the right and y grows upward,
// so y == 0 is the bottom row:
//
//   numpad:  7 8 9      directional pad:  _ ^ A
//            4 5 6                        < v >
//            1 2 3
//            _ 0 A
//
// The '_' cells are gaps an arm must never point at (see is_valid_position).
static const Vec2i TABLES[2][128] = {
    [TABLE_NUMPAD] = {
        [N_0] = { 1, 0 },
        [N_A] = { 2, 0 },
        [N_1] = { 0, 1 },
        [N_2] = { 1, 1 },
        [N_3] = { 2, 1 },
        [N_4] = { 0, 2 },
        [N_5] = { 1, 2 },
        [N_6] = { 2, 2 },
        [N_7] = { 0, 3 },
        [N_8] = { 1, 3 },
        [N_9] = { 2, 3 },
    },
    [TABLE_DIRECTIONS] = {
        [D_A]     = { 2, 1 },
        [D_UP]    = { 1, 1 },
        [D_LEFT]  = { 0, 0 },
        [D_DOWN]  = { 1, 0 },
        [D_RIGHT] = { 2, 0 },
    },
};

typedef DYNAMIC_ARRAY(str) Strings;

// NUMPAD_PATHS[a][b] (and DIRECTION_PATHS[a][b] for the directional pad) lists
// every arrow sequence that moves the arm from key a to key b in the minimum
// number of steps without pointing at the gap. The final 'A' press is NOT
// included. Both tables are filled by init_tables().
static Strings NUMPAD_PATHS[NUM_NUMPAD_KEYS][NUM_NUMPAD_KEYS];
static Strings DIRECTION_PATHS[NUM_DIRECTION_KEYS][NUM_DIRECTION_KEYS];

// Maps an arrow/activate character to its DirectionKey. Any other character is a bug.
static DirectionKey parse_direction_key(u8 ch)
{
    switch (ch) {
    case '<': return D_LEFT;
    case '^': return D_UP;
    case '>': return D_RIGHT;
    case 'v': return D_DOWN;
    case 'A': return D_A;
    default: NOT_REACHABLE();
    }
}

// Maps '0'..'9' / 'A' to its NumpadKey. Any other character is a bug.
static NumpadKey parse_numpad_key(u8 ch)
{
    if (ch >= '0' && ch <= '9') {
        return N_0 + (ch - '0');
    } else if (ch == 'A') {
        return N_A;
    } else {
        NOT_REACHABLE();
    }
}

// Growable byte buffer whose storage comes from `arena` (old storage is not reclaimed).
typedef struct {
    Arena *arena;
    isize len, cap;
    u8 *data;
} StrBuf;

// True when (x, y) is a key cell of `table` (inside the keypad and not its gap).
static inline bool is_valid_position(Table table, i32 x, i32 y)
{
    if (table == TABLE_NUMPAD) {
        if (x < 0 || x > 2) return false;
        if (y < 0 || y > 3) return false;
        if (x == 0 && y == 0) return false;
        return true;
    } else if (table == TABLE_DIRECTIONS) {
        if (x < 0 || x > 2) return false;
        if (y < 0 || y > 1) return false;
        if (x == 0 && y == 1) return false;
        return true;
    } else {
        NOT_REACHABLE();
    }
}

// Copies the current contents of `buf` into `arena`. An empty buffer yields {0}.
static str sb_copy_str(StrBuf *buf, Arena *arena)
{
    str result = {0};
    if (buf->len > 0) {
        result.data = ARENA_ALLOC_ARRAY(arena, u8, buf->len);
        result.len = buf->len;
        memcpy(result.data, buf->data, buf->len);
    }
    return result;
}

// Appends `s` to `sb`, reallocating from sb->arena when the capacity is exceeded.
static void sb_push_str(StrBuf *sb, str s)
{
    ASSERT(sb->arena);
    if (s.len <= 0) {
        return;
    }

    isize needed = sb->len + s.len;
    if (needed > sb->cap) {
        // Grow geometrically so runs of one-character pushes stay amortized O(1),
        // and never reserve less than what this push needs, however long `s` is.
        isize new_cap = sb->cap > 0 ? sb->cap * 2 : 16;
        if (new_cap < needed) {
            new_cap = needed;
        }
        u8 *new_data = ARENA_ALLOC_ARRAY(sb->arena, u8, new_cap);
        ASSERT(new_data);
        if (sb->len > 0) {
            memcpy(new_data, sb->data, sb->len);
        }
        sb->cap = new_cap;
        sb->data = new_data;
    }
    ASSERT(needed <= sb->cap);
    memcpy(sb->data + sb->len, s.data, s.len);
    sb->len = needed;
}

// Depth-first enumeration of every path from (x, y) that covers the remaining
// offset (dx, dy) one step at a time, always stepping toward the target and
// abandoning any path that reaches an invalid cell. `acc` holds the moves taken
// so far (restored before returning); each complete path is copied into
// `strings`, allocated from `arena`.
static void enumerate_paths(Arena *arena, Table table, i32 x, i32 y, i32 dx, i32 dy,
                            StrBuf *acc, Strings *strings)
{
    if (!is_valid_position(table, x, y)) return;

    if (dx == 0 && dy == 0) {
        *push(strings, arena) = sb_copy_str(acc, arena);
        return;
    }

    if (dx > 0) {
        sb_push_str(acc, STR(">"));
        enumerate_paths(arena, table, x + 1, y, dx - 1, dy, acc, strings);
        acc->len--;
    } else if (dx < 0) {
        sb_push_str(acc, STR("<"));
        enumerate_paths(arena, table, x - 1, y, dx + 1, dy, acc, strings);
        acc->len--;
    }

    if (dy > 0) {
        sb_push_str(acc, STR("^"));
        enumerate_paths(arena, table, x, y + 1, dx, dy - 1, acc, strings);
        acc->len--;
    } else if (dy < 0) {
        sb_push_str(acc, STR("v"));
        enumerate_paths(arena, table, x, y - 1, dx, dy + 1, acc, strings);
        acc->len--;
    }
}

// Returns all minimum-length arrow paths from `start_key` to `end_key` on `table`.
static Strings enumerate_paths_by_key(Arena *arena, Table table, u8 start_key, u8 end_key)
{
    i32 start_x = TABLES[table][start_key].x;
    i32 start_y = TABLES[table][start_key].y;
    i32 dx = TABLES[table][end_key].x - start_x;
    i32 dy = TABLES[table][end_key].y - start_y;
    StrBuf acc = { .arena = arena };
    Strings strings = {0};
    enumerate_paths(arena, table, start_x, start_y, dx, dy, &acc, &strings);
    return strings;
}

// Non-owning view of the buffer's current contents.
static str sb_str(StrBuf *buf)
{
    return (str) { buf->data, buf->len };
}

// Appends to `result` every press sequence (for the keypad one layer further
// out) that makes the arm on `table` type code[index..], given that the arm
// currently rests on `prev_key`. Each key costs one precomputed path plus an
// 'A'. `current_path` is the shared prefix buffer; pass NULL at the top level.
static void build_sequences(Arena *arena, Table table, str code, i32 index, u8 prev_key,
                            StrBuf *current_path, Strings *result)
{
    if (!current_path) {
        current_path = ARENA_ALLOC(arena, StrBuf);
        current_path->arena = arena;
    }

    if (index == code.len) {
        *push(result, arena) = sb_copy_str(current_path, arena);
        return;
    }

    u8 key = 0;
    Strings *paths = NULL;
    if (table == TABLE_NUMPAD) {
        key = parse_numpad_key(code.data[index]);
        ASSERT(prev_key < NUM_NUMPAD_KEYS);
        ASSERT(key < NUM_NUMPAD_KEYS);
        paths = &NUMPAD_PATHS[prev_key][key];
    } else if (table == TABLE_DIRECTIONS) {
        key = parse_direction_key(code.data[index]);
        ASSERT(prev_key < NUM_DIRECTION_KEYS);
        ASSERT(key < NUM_DIRECTION_KEYS);
        paths = &DIRECTION_PATHS[prev_key][key];
    } else {
        NOT_REACHABLE();
    }
    ASSERT(paths);

    for (isize i = 0; i < paths->len; i++) {
        sb_push_str(current_path, paths->data[i]);
        sb_push_str(current_path, STR("A"));
        build_sequences(arena, table, code, index + 1, key, current_path, result);
        // Undo this path + 'A' so the next alternative starts from the same prefix.
        current_path->len -= paths->data[i].len + 1;
    }
}

// Precomputes NUMPAD_PATHS and DIRECTION_PATHS for every (from, to) key pair.
static void init_tables(Arena *arena)
{
    for (isize i = 0; i < NUM_NUMPAD_KEYS; i++) {
        for (isize j = 0; j < NUM_NUMPAD_KEYS; j++) {
            NUMPAD_PATHS[i][j] = enumerate_paths_by_key(arena, TABLE_NUMPAD, i, j);
        }
    }
    for (isize i = 0; i < NUM_DIRECTION_KEYS; i++) {
        for (isize j = 0; j < NUM_DIRECTION_KEYS; j++) {
            DIRECTION_PATHS[i][j] = enumerate_paths_by_key(arena, TABLE_DIRECTIONS, i, j);
        }
    }
}

// Memo entry: shortest expansion length of `code` through `depth` directional layers.
typedef struct CacheEntry {
    struct CacheEntry *next;
    u64 hash;
    str code;
    i32 depth;
    i64 value;
} CacheEntry;

#define CACHE_EXP 16
#define CACHE_SIZE (1ull << CACHE_EXP)
#define CACHE_MASK (CACHE_SIZE - 1)

// Fixed-size chained hash table keyed by (code, depth); entries live in `arena`.
typedef struct Cache {
    Arena *arena;
    CacheEntry *entries[CACHE_SIZE];
    isize count;
} Cache;

// FNV-1a over the bytes of `code`, followed by `depth`.
static u64 hash_pair(str code, i32 depth)
{
    u64 hash = 14695981039346656037ull;
    for (isize i = 0; i < code.len; i++) {
        hash ^= code.data[i];
        hash *= 1099511628211ull;
    }
    hash ^= depth;
    hash *= 1099511628211ull;
    return hash;
}

// Resets `cache` to empty and makes it allocate entries from `arena`.
static void cache_init(Cache *cache, Arena *arena)
{
    cache->arena = arena;
    cache->count = 0;
    memset(cache->entries, 0, sizeof(cache->entries));
}

// Returns the entry for (code, depth), or NULL when absent.
static CacheEntry *cache_get(Cache *cache, str code, i32 depth)
{
    u64 hash = hash_pair(code, depth);
    u64 index = hash & CACHE_MASK;
    for (CacheEntry *e = cache->entries[index]; e; e = e->next) {
        if (e->hash == hash && e->depth == depth && str_eq(e->code, code)) {
            return e;
        }
    }
    return NULL;
}

// Copies `s` into `destination`. An empty string yields {0}.
static str copy_str(Arena *destination, str s)
{
    str result = {0};
    if (s.len > 0) {
        result.data = ARENA_ALLOC_ARRAY(destination, u8, s.len);
        result.len = s.len;
        memcpy(result.data, s.data, s.len);
    }
    return result;
}

// Debug helper: dumps every cache entry and an estimate of the memory they use.
static void cache_print_stats(Cache *cache)
{
    isize total_size = 0;
    for (isize i = 0; i < CACHE_SIZE; i++) {
        for (CacheEntry *e = cache->entries[i]; e; e = e->next) {
            printf("Code: " STR_FMT " Depth: %d Value: %lld\n",
                   STR_ARG(e->code), (int)e->depth, (long long)e->value);
            total_size += e->code.len + (isize)sizeof(CacheEntry);
        }
    }
    printf("Total size occupied by cache: %lld\n", (long long)total_size);
}

// Inserts or overwrites the value for (code, depth). `code` is copied into the cache arena.
static void cache_put(Cache *cache, str code, i32 depth, i64 value)
{
    CacheEntry *existing = cache_get(cache, code, depth);
    if (existing) {
        existing->value = value;
        return;
    }

    u64 hash = hash_pair(code, depth);
    u64 index = hash & CACHE_MASK;
    CacheEntry *entry = ARENA_ALLOC(cache->arena, CacheEntry);
    entry->hash = hash;
    entry->code = copy_str(cache->arena, code);
    entry->depth = depth;
    entry->value = value;
    entry->next = cache->entries[index];
    cache->entries[index] = entry;
    cache->count++;
}

// Length of the shortest sequence the outermost keypad must type so that
// `depth` stacked directional-keypad robots end up typing `code`. depth == 0
// means `code` is typed directly. `code` is split after each 'A' because every
// arm is back on 'A' at that point, so the chunks can be minimized
// independently. Results are memoized in `cache`. `temp` is taken by value, so
// scratch allocations made here are discarded on return.
static i64 find_shortest_dpad(Arena temp, str code, i32 depth, Cache *cache)
{
    ASSERT(code.len > 0);
    if (depth == 0) {
        return code.len;
    }

    CacheEntry *cached = cache_get(cache, code, depth);
    if (cached) {
        ASSERT(cached->value > 0);
        return cached->value;
    }

    i64 result = 0;
    str rest = code;
    while (rest.len > 0) {
        str_find_result hit = str_find_left(rest, STR("A"));
        str chunk = str_sub(rest, 0, hit.found ? hit.offset + 1 : rest.len);
        rest = str_sub(rest, chunk.len, rest.len);
        ASSERT(chunk.data[chunk.len - 1] == 'A');

        // Fresh scratch per chunk: this chunk's candidate sequences are dropped afterwards.
        Arena scratch = temp;
        Strings sequences = {0};
        build_sequences(&scratch, TABLE_DIRECTIONS, chunk, 0, D_A, NULL, &sequences);

        i64 shortest = INT64_MAX;
        for (isize i = 0; i < sequences.len; i++) {
            ASSERT(sequences.data[i].data);
            i64 candidate = find_shortest_dpad(scratch, sequences.data[i], depth - 1, cache);
            ASSERT(candidate > 0);
            shortest = MIN(shortest, candidate);
        }
        ASSERT(shortest > 0);
        result += shortest;
    }

    cache_put(cache, code, depth, result);
    ASSERT(result > 0);
    return result;
}

// Shortest number of presses for a chain of one numeric-keypad robot plus
// (depth - 1) directional-keypad layers to type the door code `code`.
static i64 find_shortest(Arena temp, str code, i32 depth, Cache *cache)
{
    Strings sequences = {0};
    build_sequences(&temp, TABLE_NUMPAD, code, 0, N_A, NULL, &sequences);
    i64 shortest = INT64_MAX;
    for (isize i = 0; i < sequences.len; i++) {
        i64 candidate = find_shortest_dpad(temp, sequences.data[i], depth - 1, cache);
        shortest = MIN(shortest, candidate);
    }
    return shortest;
}

// Parses the run of leading decimal digits of `s` (e.g. "029A" -> 29).
static i64 parse_numeric_prefix(Arena *arena, str s)
{
    isize end = 0;
    while (end < s.len && s.data[end] >= '0' && s.data[end] <= '9') {
        end++;
    }
    str prefix = { s.data, end };
    return parse_i64(prefix, *arena);
}

// Characters under each keypad cell, indexed [y][x] with y == 0 the bottom row.
static const u8 NUMPAD[4][3] = {
    { '\0', '0', 'A' },
    { '1', '2', '3' },
    { '4', '5', '6' },
    { '7', '8', '9' },
};

static const u8 DPAD[2][3] = {
    { '<', 'v', '>' },
    { '\0', '^', 'A' },
};

// Self-check helper that simulates `depth` keypad layers. `input` is pressed on
// the outermost keypad: depth == 1 means it drives the numeric keypad directly,
// depth > 1 means it drives a directional keypad whose output feeds the next
// layer. Returns, copied into `perm`, what the numeric keypad finally types.
// Asserts that no arm ever points outside its keypad or at its gap.
static str test(Arena *perm, Arena temp, str input, i32 depth)
{
    if (depth == 0) {
        return copy_str(perm, input);
    }

    // Every arm starts on its keypad's 'A' key.
    Table table = depth == 1 ? TABLE_NUMPAD : TABLE_DIRECTIONS;
    i32 x = 2;
    i32 y = depth == 1 ? 0 : 1;

    StrBuf output = { .arena = &temp };
    for (isize i = 0; i < input.len; i++) {
        switch (input.data[i]) {
        case '>': x++; break;
        case '<': x--; break;
        case '^': y++; break;
        case 'v': y--; break;
        case 'A': {
            ASSERT(is_valid_position(table, x, y));
            u8 symbol = depth == 1 ? NUMPAD[y][x] : DPAD[y][x];
            str s = { &symbol, 1 };
            sb_push_str(&output, s);
        } break;
        default:
            NOT_REACHABLE();
        }
        ASSERT(is_valid_position(table, x, y));
    }
    return test(perm, temp, sb_str(&output), depth - 1);
}

int main(int argc, char **argv)
{
    if (argc < 2) {
        fprintf(stderr, "usage: %s <input-file>\n", argc > 0 ? argv[0] : "day21");
        return 1;
    }

    Arena *perm = make_arena(Megabytes(2));
    Arena *temp = make_arena(Megabytes(1));
    Tokens lines = read_lines(perm, argv[1]);

    // Worked example from the puzzle statement: each sequence, pressed through
    // 1, 2 and 3 keypad layers respectively, types "029A".
    ASSERT(str_eq(test(perm, *temp, STR("<A^A>^^AvvvA"), 1), STR("029A")));
    ASSERT(str_eq(test(perm, *temp, STR("v<<A>>^A<A>AvA<^AA>A<vAAA>^A"), 2), STR("029A")));
    ASSERT(str_eq(test(perm, *temp,
                       STR("<vA<AA>>^AvAA<^A>A<v<A>>^AvA^A<vA>^A<v<A>^A>AAvA^A<v<A>A>^AAAvA<^A>A"), 3),
                  STR("029A")));

    init_tables(perm);
    Cache *cache = ARENA_ALLOC(perm, Cache);
    cache_init(cache, perm);

    i64 part_1 = 0;
    i64 part_2 = 0;
    for (isize i = 0; i < lines.len; i++) {
        str line = str_trim(lines.tokens[i]);
        if (line.len == 0) {
            continue; // e.g. a trailing blank line; an empty code would trip find_shortest_dpad
        }
        i64 prefix = parse_numeric_prefix(temp, line);
        // Depth = the numeric-keypad robot + N directional-keypad robots.
        i64 shortest_part_1 = find_shortest(*temp, line, 1 + 2, cache);
        i64 shortest_part_2 = find_shortest(*temp, line, 1 + 25, cache);
        part_1 += shortest_part_1 * prefix;
        part_2 += shortest_part_2 * prefix;
    }

    // i64 may be long or long long depending on platform; cast for a portable specifier.
    printf("%lld\n", (long long)part_1);
    printf("%lld\n", (long long)part_2);
    return 0;
}