11from typing import List , Optional , Tuple
22from collections import deque
3+ from dataclasses import dataclass
4+
5+
6+ class Mem (list ):
7+ def resolve_ptr (self , mode : int , ptr : int , rptr : int ) -> int :
8+ if mode == 0 :
9+ ptr = self [ptr ]
10+ elif mode == 1 :
11+ ptr = ptr
12+ elif mode == 2 :
13+ ptr = self [ptr ] + rptr
14+ else :
15+ raise Exception (f'unknown mode: { mode } ' )
16+
17+ return ptr
18+
19+ def check_mem (self , ptr : int ) -> None :
20+ if ptr < 0 :
21+ raise Exception (f'negative pointer: f{ ptr } ' )
22+
23+ if ptr >= len (self ):
24+ self += [0 ] * ((ptr + 1 ) - len (self ))
25+
26+ def get (self , mode : int , ptr : int , rptr : int ) -> int :
27+ ptr = self .resolve_ptr (mode , ptr , rptr )
28+ self .check_mem (ptr )
29+ # return self._mem[ptr]
30+ return Addr (self , mode , ptr )
31+
32+ # def set(self, mode: int, ptr: int, rptr: int, val: int):
33+ # ptr = self.resolve_ptr(mode, ptr, rptr)
34+ # self.check_mem(ptr)
35+ # self._mem[ptr] = val
36+
37+
38+ @dataclass
39+ class Addr :
40+ mem : Mem
41+ mode : int
42+ ptr : int
43+
44+ def read (self ) -> int :
45+ return self .mem [self .ptr ]
46+
47+ def write (self , value : int ):
48+ if self .mode == 1 :
49+ raise Exception (f'immediate write: { self .ptr } : { value } ' )
50+ self .mem [self .ptr ] = value
351
452
553class Emulator :
654 def __init__ (self , program : List [int ], input : Optional [deque ] = None , output : Optional [deque ] = None ):
7- self ._mem = program
8- self ._ptr = 0
55+ self .mem = Mem (program )
956 self .input = input
57+ self .ptr = 0
58+ self .rptr = 0
1059 self .output = output
1160 self .stopped = False
1261
@@ -26,15 +75,11 @@ def __init__(self, program: List[int], input: Optional[deque] = None, output: Op
2675 6 : self ._jumpfalse ,
2776 7 : self ._lessthan ,
2877 8 : self ._equals ,
78+ 9 : self ._rel_base_offset ,
2979 }
3080
31- def _g (self , mode : int , ptr : int ) -> int :
32- if mode == 1 :
33- return self ._mem [ptr ]
34- return self ._mem [self ._mem [ptr ]]
35-
36- def _p (self , count : int , modes : List [int ]) -> List [int ]:
37- return [self ._g (modes [i ], self ._ptr + i ) for i in range (count )]
81+ def _p (self , count : int , modes : List [int ]) -> List [Addr ]:
82+ return [self .mem .get (modes [i ], self .ptr + i , self .rptr ) for i in range (count )]
3883
3984 def _parse_instruction (self , instruction : int ) -> Tuple [int , List [int ]]:
4085 op = instruction % 100
@@ -47,76 +92,82 @@ def _parse_instruction(self, instruction: int) -> Tuple[int, List[int]]:
4792
4893 def run (self ) -> None :
4994 while not self .stopped :
50- op , modes = self ._parse_instruction ( self . _mem [ self . _ptr ])
51- self . _ptr += 1
95+ if not self .step ():
96+ return
5297
53- if op not in self ._op_funcs :
54- raise Exception (f'invalid opcode: { op } ' )
98+ def step (self ) -> bool :
99+ op , modes = self ._parse_instruction (
100+ self .mem .get (1 , self .ptr , self .rptr ).read ()
101+ )
102+ self .ptr += 1
55103
56- if not self ._op_funcs [op ](modes ):
57- return
104+ if op not in self ._op_funcs :
105+ raise Exception (f'invalid opcode: { op } ' )
106+
107+ return self ._op_funcs [op ](modes )
58108
59109 def _break (self , modes : List [int ]) -> bool :
60110 self .stopped = True
61111 return False
62112
63113 def _add (self , modes : List [int ]) -> bool :
64- modes [2 ] = 1
65114 p = self ._p (3 , modes )
66- self ._ptr += 3
67- self . _mem [ p [2 ]] = p [0 ] + p [1 ]
115+ self .ptr += 3
116+ p [2 ]. write ( p [0 ]. read () + p [1 ]. read ())
68117 return True
69118
70119 def _mul (self , modes : List [int ]) -> bool :
71- modes [2 ] = 1
72120 p = self ._p (3 , modes )
73- self ._ptr += 3
74- self . _mem [ p [2 ]] = p [0 ] * p [1 ]
121+ self .ptr += 3
122+ p [2 ]. write ( p [0 ]. read () * p [1 ]. read ())
75123 return True
76124
77125 def _input (self , modes : List [int ]) -> bool :
78126 try :
79127 input = self .input .popleft ()
80128 except IndexError :
81- self ._ptr -= 1
129+ self .ptr -= 1
82130 return False
83131
84- modes [0 ] = 1
85132 p = self ._p (1 , modes )
86- self ._ptr += 1
87- self . _mem [ p [0 ]] = input
133+ self .ptr += 1
134+ p [0 ]. write ( input )
88135 return True
89136
90137 def _output (self , modes : List [int ]) -> bool :
91138 p = self ._p (1 , modes )
92- self ._ptr += 1
93- self .output .append (p [0 ])
139+ self .ptr += 1
140+ self .output .append (p [0 ]. read () )
94141 return True
95142
96143 def _jumptrue (self , modes : List [int ]) -> bool :
97144 p = self ._p (2 , modes )
98- self ._ptr += 2
99- if p [0 ] != 0 :
100- self ._ptr = p [1 ]
145+ self .ptr += 2
146+ if p [0 ]. read () != 0 :
147+ self .ptr = p [1 ]. read ()
101148 return True
102149
103150 def _jumpfalse (self , modes : List [int ]) -> bool :
104151 p = self ._p (2 , modes )
105- self ._ptr += 2
106- if p [0 ] == 0 :
107- self ._ptr = p [1 ]
152+ self .ptr += 2
153+ if p [0 ]. read () == 0 :
154+ self .ptr = p [1 ]. read ()
108155 return True
109156
110157 def _lessthan (self , modes : List [int ]) -> bool :
111- modes [2 ] = 1
112158 p = self ._p (3 , modes )
113- self ._ptr += 3
114- self . _mem [ p [2 ]] = int (p [0 ] < p [1 ])
159+ self .ptr += 3
160+ p [2 ]. write ( int (p [0 ]. read () < p [1 ]. read ()) )
115161 return True
116162
117163 def _equals (self , modes : List [int ]) -> bool :
118- modes [2 ] = 1
119164 p = self ._p (3 , modes )
120- self ._ptr += 3
121- self ._mem [p [2 ]] = int (p [0 ] == p [1 ])
165+ self .ptr += 3
166+ p [2 ].write (int (p [0 ].read () == p [1 ].read ()))
167+ return True
168+
169+ def _rel_base_offset (self , modes : List [int ]) -> bool :
170+ p = self ._p (1 , modes )
171+ self .ptr += 1
172+ self .rptr += p [0 ].read ()
122173 return True
0 commit comments