import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import time

# --- CONFIGURATION ---
BOARD_SIZE = 8
DEVICE = torch.device("cpu")
MODEL_PATH = "checkers_master_final.pth"

st.set_page_config(page_title="AlphaCheckerZero", page_icon="♟️", layout="wide")

# --- CUSTOM CSS STYLES ---
st.markdown(""" """, unsafe_allow_html=True)

# --- GAME LOGIC ---
class Checkers:
    """Board encoding: 1/-1 are men, 2/-2 are kings, 0 is empty. Player 1 moves up the board (decreasing row)."""

    def get_initial_board(self):
        board = np.zeros((BOARD_SIZE, BOARD_SIZE), dtype=np.int8)
        for r in range(3):
            for c in range(BOARD_SIZE):
                if (r + c) % 2 == 1:
                    board[r, c] = -1
        for r in range(5, BOARD_SIZE):
            for c in range(BOARD_SIZE):
                if (r + c) % 2 == 1:
                    board[r, c] = 1
        return board

    def get_valid_moves(self, board, player):
        # Captures are mandatory: if any jump exists, only the (maximal) jumps are legal.
        jumps = self._get_all_jumps(board, player)
        if jumps:
            return jumps
        moves = []
        for r in range(BOARD_SIZE):
            for c in range(BOARD_SIZE):
                if board[r, c] * player > 0:
                    moves.extend(self._get_simple_moves(board, r, c))
        return moves

    def _get_simple_moves(self, board, r, c):
        moves = []
        piece = board[r, c]
        player = np.sign(piece)
        directions = [(-1, -1), (-1, 1)] if player == 1 else [(1, -1), (1, 1)]
        if abs(piece) == 2:  # kings also move backwards
            directions.extend([(1, -1), (1, 1)] if player == 1 else [(-1, -1), (-1, 1)])
        for dr, dc in directions:
            nr, nc = r + dr, c + dc
            if 0 <= nr < BOARD_SIZE and 0 <= nc < BOARD_SIZE and board[nr, nc] == 0:
                moves.append(((r, c), (nr, nc)))
        return moves

    def _get_all_jumps(self, board, player):
        all_jumps = []
        for r in range(BOARD_SIZE):
            for c in range(BOARD_SIZE):
                if board[r, c] * player > 0:
                    jumps = self._find_jump_sequences(np.copy(board), r, c)
                    if jumps:
                        all_jumps.extend(jumps)
        if not all_jumps:
            return []
        # Only the longest capture sequences are legal.
        max_len = max(len(j) for j in all_jumps)
        return [j for j in all_jumps if len(j) == max_len]

    def _find_jump_sequences(self, board, r, c, path=None):
        # `path` accumulates the capture chain built so far (avoids a mutable default argument).
        path = path or []
        piece = board[r, c]
        player = np.sign(piece)
        if piece == 0:
            return []
        directions = [(-1, -1), (-1, 1), (1, -1), (1, 1)] if abs(piece) == 2 else \
                     [(-1, -1), (-1, 1)] if player == 1 else [(1, -1), (1, 1)]
        found_jumps = []
        for dr, dc in directions:
            mid_r, mid_c = r + dr, c + dc
            end_r, end_c = r + 2 * dr, c + 2 * dc
            if 0 <= end_r < BOARD_SIZE and 0 <= end_c < BOARD_SIZE and \
               board[mid_r, mid_c] * player < 0 and board[end_r, end_c] == 0:
                move = ((r, c), (end_r, end_c))
                new_board = np.copy(board)
                new_board[end_r, end_c] = piece
                new_board[r, c] = 0
                new_board[mid_r, mid_c] = 0
                next_jumps = self._find_jump_sequences(new_board, end_r, end_c, path + [move])
                if next_jumps:
                    found_jumps.extend(next_jumps)
                else:
                    found_jumps.append(path + [move])
        return found_jumps

    def apply_move(self, board, move):
        b_ = np.copy(board)
        # Moves arrive either as a single ((r1, c1), (r2, c2)) pair or as a capture chain
        # (a list, or a tuple of pairs after MCTS converts it), so normalize the format here.
        is_jump_chain = False
        if isinstance(move, list):
            is_jump_chain = True
        elif isinstance(move, tuple) and len(move) > 0 and isinstance(move[0], tuple) \
                and len(move[0]) > 0 and isinstance(move[0][0], tuple):
            is_jump_chain = True
        sub_moves = move if is_jump_chain else [move]
        for (r1, c1), (r2, c2) in sub_moves:
            piece = b_[r1, c1]
            if piece == 0:
                continue
            b_[r2, c2] = piece
            b_[r1, c1] = 0
            if abs(r1 - r2) == 2:  # a jump removes the captured piece in between
                b_[(r1 + r2) // 2, (c1 + c2) // 2] = 0
        # Promotion on the last square reached.
        r_final, c_final = sub_moves[-1][1]
        p_final = b_[r_final, c_final]
        if p_final == 1 and r_final == 0:
            b_[r_final, c_final] = 2
        if p_final == -1 and r_final == BOARD_SIZE - 1:
            b_[r_final, c_final] = -2
        return b_

    def check_game_over(self, board, player):
        # Returns -1 if `player` has no legal moves (loss), 1 if the opponent has no pieces left, None otherwise.
        if not self.get_valid_moves(board, player):
            return -1
        if not np.any(np.sign(board) == -player):
            return 1
        return None
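
# A minimal sanity-check sketch (illustration only; the helper name below is made up and is never
# called by the app). It exercises the move formats produced above: a simple move is a single
# ((r1, c1), (r2, c2)) pair, while a capture is a list of such pairs covering the whole jump chain.
def _rules_smoke_test():
    game = Checkers()
    board = game.get_initial_board()
    moves = game.get_valid_moves(board, player=1)      # white to move from the initial position
    assert moves and not isinstance(moves[0], list)    # no captures are possible yet, only simple moves
    return game.apply_move(board, moves[0])            # apply_move returns a new board; the input is untouched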
def state_to_tensor(board, player):
    # Five planes: own men, own kings, opponent men, opponent kings, and a side-to-move flag.
    tensor = np.zeros((5, BOARD_SIZE, BOARD_SIZE), dtype=np.float32)
    tensor[0, board == player] = 1
    tensor[1, board == player * 2] = 1
    tensor[2, board == -player] = 1
    tensor[3, board == -player * 2] = 1
    if player == 1:
        tensor[4, :, :] = 1.0
    return torch.from_numpy(tensor).unsqueeze(0).to(DEVICE)


class PolicyValueNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        num_channels = 64
        self.body = nn.Sequential(
            nn.Conv2d(5, num_channels, 3, padding=1), nn.BatchNorm2d(num_channels), nn.ReLU(),
            nn.Conv2d(num_channels, num_channels, 3, padding=1), nn.BatchNorm2d(num_channels), nn.ReLU(),
            nn.Conv2d(num_channels, num_channels, 3, padding=1), nn.BatchNorm2d(num_channels), nn.ReLU())
        # The policy head scores the 64 origin squares; the value head outputs a scalar in [-1, 1].
        self.policy_head = nn.Sequential(
            nn.Conv2d(num_channels, 4, 1), nn.BatchNorm2d(4), nn.ReLU(), nn.Flatten(),
            nn.Linear(4 * BOARD_SIZE * BOARD_SIZE, BOARD_SIZE * BOARD_SIZE))
        self.value_head = nn.Sequential(
            nn.Conv2d(num_channels, 2, 1), nn.BatchNorm2d(2), nn.ReLU(), nn.Flatten(),
            nn.Linear(2 * BOARD_SIZE * BOARD_SIZE, 64), nn.ReLU(), nn.Linear(64, 1), nn.Tanh())

    def forward(self, x):
        x = self.body(x)
        return self.policy_head(x), self.value_head(x)


class MCTSNode:
    def __init__(self, parent=None, prior=0.0):
        self.parent = parent
        self.prior = prior
        self.children = {}
        self.visits = 0
        self.value_sum = 0.0

    def get_value(self):
        return self.value_sum / self.visits if self.visits > 0 else 0.0


class MCTS:
    def __init__(self, game, model, sims=100, c_puct=1.5):
        self.game, self.model, self.sims, self.c_puct = game, model, sims, c_puct

    def run(self, board, player):
        root = MCTSNode()
        self._expand_and_evaluate(root, board, player)
        for _ in range(self.sims):
            node, search_board, search_player = root, np.copy(board), player
            search_path = [root]
            # Selection: descend the tree until a leaf is reached.
            while node.children:
                move, node = self._select_child(node)
                search_board = self.game.apply_move(search_board, move)
                search_player *= -1
                search_path.append(node)
            value = self.game.check_game_over(search_board, search_player)
            if value is None and node.visits == 0:
                value = self._expand_and_evaluate(node, search_board, search_player)
            elif value is None:
                value = node.get_value()
            # Backpropagation: players alternate, so the sign flips at each level.
            for n in reversed(search_path):
                n.visits += 1
                n.value_sum += value
                value *= -1
        moves = list(root.children.keys())
        visits = np.array([root.children[m].visits for m in moves])
        return moves, visits / (np.sum(visits) + 1e-9)

    def _select_child(self, node):
        # PUCT score: the child's value is negated (it is stored from the opponent's point of view)
        # plus an exploration bonus weighted by the prior.
        sqrt_total_visits = np.sqrt(node.visits)
        best_move, max_score = None, -np.inf
        for move, child in node.children.items():
            score = -child.get_value() + self.c_puct * child.prior * sqrt_total_visits / (1 + child.visits)
            if score > max_score:
                max_score, best_move = score, move
        return best_move, node.children[best_move]

    def _expand_and_evaluate(self, node, board, player):
        valid_moves = self.game.get_valid_moves(board, player)
        if not valid_moves:
            return -1.0
        with torch.no_grad():
            policy_logits, value_tensor = self.model(state_to_tensor(board, player))
        value = value_tensor.item()
        policy_probs = F.softmax(policy_logits, dim=1).cpu().numpy()[0]
        move_priors = {}
        total_prior = 0
        for move in valid_moves:
            # The policy head only scores the origin square, so every move starting from the
            # same square shares that prior.
            if isinstance(move, list):
                start_pos_tuple = move[0][0]
            else:
                start_pos_tuple = move[0]
            start_pos_idx = start_pos_tuple[0] * BOARD_SIZE + start_pos_tuple[1]
            prior = policy_probs[start_pos_idx]
            # IMPORTANT: MCTS converts lists to tuples here so they can be used as dictionary keys.
            key = tuple(move) if isinstance(move, list) else move
            move_priors[key] = prior
            total_prior += prior
        if total_prior > 0:
            for move_key, prior in move_priors.items():
                node.children[move_key] = MCTSNode(parent=node, prior=prior / total_prior)
        else:
            for move in valid_moves:
                key = tuple(move) if isinstance(move, list) else move
                node.children[key] = MCTSNode(parent=node, prior=1.0 / len(valid_moves))
        return value
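
# A minimal usage sketch (illustration only; the helper name below is made up and is never called
# by the app). It shows how the pieces fit together -- encode the position, run a short search,
# and play the most-visited move -- and assumes the side to move still has legal moves.
def _example_ai_turn(model, board, player=-1, sims=50):
    game = Checkers()
    mcts = MCTS(game, model, sims=sims)
    moves, probs = mcts.run(board, player)
    best_move = moves[int(np.argmax(probs))]   # visit counts act as the post-search policy
    return game.apply_move(board, best_move)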
# --- GRAPHICAL INTERFACE ---
@st.cache_resource
def load_model():
    if not os.path.exists(MODEL_PATH):
        return None
    model = PolicyValueNetwork().to(DEVICE)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()
    return model


model = load_model()
if model is None:
    st.error(f"File '{MODEL_PATH}' not found!")
    st.stop()

if "board" not in st.session_state:
    game = Checkers()
    st.session_state.board = game.get_initial_board()
    st.session_state.player = 1
    st.session_state.game_over = False
    st.session_state.message = "Welcome to the Arena! You play WHITE (⚪)."

game = Checkers()
mcts = MCTS(game, model, sims=150)


def format_move_for_human(move):
    """
    Format a move as human-readable text.
    Handles both lists (raw jumps) and nested tuples (jumps converted by MCTS).
    """
    is_jump = False
    # 1. A list is always a jump.
    if isinstance(move, list):
        is_jump = True
    # 2. For a tuple, inspect the contents to tell a jump from a simple move.
    elif isinstance(move, tuple):
        # If the first item is ANOTHER pair of tuples (e.g. ((r,c), (r,c))), it is a converted jump.
        # A simple move has an int at that depth: ((2,3), (3,4)) -> move[0] is (2,3), move[0][0] is 2 (an int).
        if len(move) > 0 and isinstance(move[0], tuple) and len(move[0]) > 0 and isinstance(move[0][0], tuple):
            is_jump = True
    if is_jump:
        # Multi-jump or single capture: list the origin of each hop plus the final square.
        end_r, end_c = move[-1][1]
        path = " -> ".join([f"({r},{c})" for (r, c), _ in move] + [f"({end_r},{end_c})"])
        return f"Jump/Capture: {path}"
    else:
        # Simple move
        (r1, c1), (r2, c2) = move
        return f"Move from ({r1}, {c1}) to ({r2}, {c2})"


def render_board_html(board):
    html = '