Skip to content

Commit eb155e4

Browse files
committed
chess.pgn typing tweaks
1 parent d6fa964 commit eb155e4

1 file changed

Lines changed: 41 additions & 21 deletions

File tree

chess/pgn.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import logging
2222
import re
2323
import weakref
24+
import typing
2425

2526
import chess
2627

@@ -137,6 +138,8 @@ def board(self, *, _cache: bool = True) -> chess.Board:
137138
138139
It's a copy, so modifying the board will not alter the game.
139140
"""
141+
assert self.parent is not None and self.move is not None, "cannot get board of dangling GameNode"
142+
140143
if self.board_cached is not None:
141144
board = self.board_cached()
142145
if board is not None:
@@ -151,13 +154,18 @@ def board(self, *, _cache: bool = True) -> chess.Board:
151154
else:
152155
return board
153156

157+
def _move(self) -> chess.Move:
158+
assert self.move is not None, "cannot get move of dangling GameNode"
159+
return self.move
160+
154161
def san(self) -> str:
155162
"""
156163
Gets the standard algebraic notation of the move leading to this node.
157164
See :func:`chess.Board.san()`.
158165
159166
Do not call this on the root node.
160167
"""
168+
assert self.parent is not None and self.move is not None, "cannot get san of dangling GameNode"
161169
return self.parent.board().san(self.move)
162170

163171
def uci(self, *, chess960: Optional[bool] = None) -> str:
@@ -167,6 +175,7 @@ def uci(self, *, chess960: Optional[bool] = None) -> str:
167175
168176
Do not call this on the root node.
169177
"""
178+
assert self.parent is not None and self.move is not None, "cannot get uci of dangling GameNode"
170179
return self.parent.board().uci(self.move, chess960=chess960)
171180

172181
def root(self) -> "GameNode":
@@ -301,7 +310,7 @@ def mainline(self) -> "Mainline[GameNode]":
301310

302311
def mainline_moves(self) -> "Mainline[chess.Move]":
303312
"""Returns an iterator over the main moves after this node."""
304-
return Mainline(self, lambda node: node.move)
313+
return Mainline(self, lambda node: node._move())
305314

306315
def add_line(self, moves: Iterable[chess.Move], *, comment: str = "", starting_comment: str = "", nags: Iterable[int] = ()) -> "GameNode":
307316
"""
@@ -326,6 +335,8 @@ def add_line(self, moves: Iterable[chess.Move], *, comment: str = "", starting_c
326335
return node
327336

328337
def _accept_node(self, parent_board: chess.Board, visitor) -> None:
338+
assert self.move is not None, "cannot visit dangling GameNode"
339+
329340
if self.starting_comment:
330341
visitor.visit_comment(self.starting_comment)
331342

@@ -346,6 +357,8 @@ def accept(self, visitor, *, _parent_board: Optional[chess.Board] = None):
346357
Traverses game nodes in PGN order using the given *visitor*. Starts with
347358
the move leading to this node. Returns the *visitor* result.
348359
"""
360+
assert self.parent is not None and self.move is not None, "cannot visit dangling GameNode"
361+
349362
board = self.parent.board() if _parent_board is None else _parent_board
350363

351364
# First, visit the move that leads to this node.
@@ -402,12 +415,16 @@ def __str__(self) -> str:
402415
return self.accept(StringExporter(columns=None))
403416

404417
def __repr__(self) -> str:
405-
return "<{} at {:#x} ({}{} {} ...)>".format(
406-
type(self).__name__,
407-
id(self),
408-
self.parent.board().fullmove_number,
409-
"." if self.parent.board().turn == chess.WHITE else "...",
410-
self.san())
418+
if self.parent is None:
419+
return f"<{type(self).__name__} at {id(self):#x} (dangling)>"
420+
else:
421+
parent_board = self.parent.board()
422+
return "<{} at {:#x} ({}{} {} ...)>".format(
423+
type(self).__name__,
424+
id(self),
425+
parent_board.fullmove_number,
426+
"." if parent_board.turn == chess.WHITE else "...",
427+
self.san())
411428

412429

413430
GameT = TypeVar("GameT", bound="Game")
@@ -439,24 +456,25 @@ def setup(self, board: Union[chess.Board, str]) -> None:
439456
``FEN``, ``SetUp``, and ``Variant`` header tags.
440457
"""
441458
try:
442-
fen = board.fen()
459+
fen = board.fen() # type: ignore
460+
setup = typing.cast(chess.Board, board)
443461
except AttributeError:
444-
board = chess.Board(board)
445-
board.chess960 = board.has_chess960_castling_rights()
446-
fen = board.fen()
462+
setup = chess.Board(board) # type: ignore
463+
setup.chess960 = setup.has_chess960_castling_rights()
464+
fen = setup.fen()
447465

448-
if fen == type(board).starting_fen:
466+
if fen == type(setup).starting_fen:
449467
self.headers.pop("SetUp", None)
450468
self.headers.pop("FEN", None)
451469
else:
452470
self.headers["SetUp"] = "1"
453471
self.headers["FEN"] = fen
454472

455-
if type(board).aliases[0] == "Standard" and board.chess960:
473+
if type(setup).aliases[0] == "Standard" and setup.chess960:
456474
self.headers["Variant"] = "Chess960"
457-
elif type(board).aliases[0] != "Standard":
458-
self.headers["Variant"] = type(board).aliases[0]
459-
self.headers["FEN"] = board.fen()
475+
elif type(setup).aliases[0] != "Standard":
476+
self.headers["Variant"] = type(setup).aliases[0]
477+
self.headers["FEN"] = fen
460478
else:
461479
self.headers.pop("Variant", None)
462480

@@ -676,7 +694,7 @@ def __reversed__(self) -> Mainline[MainlineMapT]:
676694
def __repr__(self) -> str:
677695
return "<ReverseMainline at {:#x} ({})>".format(
678696
id(self),
679-
" ".join(ReverseMainline(self.stop, lambda node: node.move.uci())))
697+
" ".join(ReverseMainline(self.stop, lambda node: node._move().uci())))
680698

681699

682700
class BaseVisitor:
@@ -781,7 +799,7 @@ class GameBuilder(BaseVisitor):
781799
Creates a game model. Default visitor for :func:`~chess.pgn.read_game()`.
782800
"""
783801

784-
def __init__(self, *, Game=Game) -> None:
802+
def __init__(self, *, Game: Type[Game] = Game) -> None:
785803
self.Game = Game
786804

787805
def begin_game(self) -> None:
@@ -801,7 +819,9 @@ def visit_nag(self, nag: int) -> None:
801819
self.variation_stack[-1].nags.add(nag)
802820

803821
def begin_variation(self) -> None:
804-
self.variation_stack.append(self.variation_stack[-1].parent)
822+
parent = self.variation_stack[-1].parent
823+
assert parent is not None, "begin_variation called, but root node on top of stack"
824+
self.variation_stack.append(parent)
805825
self.in_variation = False
806826

807827
def end_variation(self) -> None:
@@ -862,7 +882,7 @@ def handle_error(self, error: Exception) -> None:
862882
LOGGER.exception("error during pgn parsing")
863883
self.game.errors.append(error)
864884

865-
def result(self):
885+
def result(self) -> Game:
866886
"""
867887
Returns the visited :class:`~chess.pgn.Game()`.
868888
"""
@@ -1029,7 +1049,7 @@ def visit_move(self, board: chess.Board, move: chess.Move) -> None:
10291049
def visit_result(self, result: str) -> None:
10301050
self.write_token(result + " ")
10311051

1032-
def result(self):
1052+
def result(self) -> str:
10331053
if self.current_line:
10341054
return "\n".join(itertools.chain(self.lines, [self.current_line.rstrip()])).rstrip()
10351055
else:

0 commit comments

Comments
 (0)