Skip to content

Commit

Permalink
Refactor code and fix type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
miskibin committed Nov 28, 2023
1 parent 6f66e4a commit 4e51ad1
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 21 deletions.
5 changes: 0 additions & 5 deletions draughts/boards/american.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@ class Board(BaseBoard):

size = int(np.sqrt(len(STARTING_POSITION) * 2))

# def __init__(
# self, starting_position=STARTING_POSITION, turn=STARTING_COLOR, *args, **kwargs
# ) -> None:
# super().__init__(starting_position, turn, *args, **kwargs)

@property
def is_draw(self) -> bool:
return self.is_threefold_repetition
Expand Down
6 changes: 3 additions & 3 deletions draughts/boards/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
from abc import ABC, abstractproperty
from typing import Generator, Literal
from typing import Generator, Literal, Optional
from typing import Type
import numpy as np

Expand Down Expand Up @@ -109,8 +109,8 @@ def __init_subclass__(cls, **kwargs):

def __init__(
self,
starting_position: np.ndarray = None,
turn: Color = None,
starting_position: Optional[np.ndarray] = None,
turn: Optional[Color] = None,
) -> None:
"""
Initializes the board with a starting position.
Expand Down
8 changes: 4 additions & 4 deletions draughts/move.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
self,
visited_squares: list[int],
captured_list: list[int] = [],
captured_entities: list[Figure.value] = [],
captured_entities: list[int] = [],
is_promotion: bool = False,
) -> None:
self.square_list = visited_squares
Expand Down Expand Up @@ -115,10 +115,10 @@ def from_uci(cls, move: str, legal_moves: Generator) -> Move:
else:
raise ValueError(f"Invalid move {move}.")

move = Move([int(step) - 1 for step in steps])
move_obj = Move([int(step) - 1 for step in steps])
for legal_move in legal_moves:
if legal_move == move:
if legal_move == move_obj:
return legal_move
raise ValueError(
f"{str(move)} is correct, but not legal in given position.\n Legal moves are: {list(map(str,legal_moves))}"
f"{str(move_obj)} is correct, but not legal in given position.\n Legal moves are: {list(map(str,legal_moves))}"
)
14 changes: 5 additions & 9 deletions draughts/server/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from collections import defaultdict
from pathlib import Path
from typing import Literal
from typing import Literal, Callable

import numpy as np
import uvicorn
Expand All @@ -10,7 +10,6 @@
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field

from draughts.boards.base import BaseBoard, Color


Expand All @@ -30,7 +29,9 @@ class Server:
def __init__(
self,
board: BaseBoard,
get_best_move_method: callable = None,
get_best_move_method: Callable = lambda board: np.random.choice(
list(board.legal_moves)
),
):
self.get_best_move_method = get_best_move_method
self.board = board
Expand All @@ -51,10 +52,6 @@ def __init__(
)
self.router.add_api_route("/pop", self.pop, methods=["GET"])
self.APP.include_router(self.router)
if not get_best_move_method:
self.get_best_move_method = lambda board: np.random.choice(
list(board.legal_moves)
)

def get_fen(self):
return {"fen": self.board.fen}
Expand Down Expand Up @@ -88,11 +85,10 @@ def position_json(self) -> PositionResponse:
history.append([(idx // 2) + 1, str(stack[idx])])
else:
history[-1].append(str(stack[idx]))
turn = "white" if self.board.turn == Color.WHITE else "black"
return PositionResponse(
position=self.board.friendly_form.tolist(),
history=history,
turn=turn,
turn="white" if self.board.turn == Color.WHITE else "black",
)

def get_position(self, request: Request) -> PositionResponse:
Expand Down

0 comments on commit 4e51ad1

Please sign in to comment.