import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import time

# --- CONFIGURATION ---
BOARD_SIZE = 8
DEVICE = torch.device("cpu")
MODEL_PATH = "checkers_master_final.pth"

st.set_page_config(page_title="AlphaCheckerZero", page_icon="♟️", layout="wide")

# --- CUSTOM CSS STYLES ---
st.markdown("""
<style>
    .board-container {
        display: grid;
        grid-template-columns: 30px repeat(8, 60px);
        grid-template-rows: 30px repeat(8, 60px);
        gap: 2px;
        background-color: #444;
        padding: 10px;
        border-radius: 10px;
        width: fit-content;
        margin: 0 auto;
    }
    .header-cell {
        display: flex;
        align-items: center;
        justify-content: center;
        color: white;
        font-weight: bold;
        font-family: monospace;
    }
    .square {
        width: 60px;
        height: 60px;
        display: flex;
        align-items: center;
        justify-content: center;
        font-size: 40px;
        cursor: default;
    }
    .white-square { background-color: #f0d9b5; }
    .black-square { background-color: #b58863; }
    .piece-white {
        color: #fff;
        text-shadow: 0 0 5px rgba(0,0,0,0.5);
        transform: scale(1.2);
    }
    .piece-black {
        color: #333;
        text-shadow: 0 0 2px rgba(255,255,255,0.3);
        transform: scale(1.2);
    }
    .king { border: 2px solid gold; border-radius: 50%; padding: 2px; box-shadow: 0 0 10px gold; }
    .stSelectbox label { font-size: 1.2rem; font-weight: bold; }
</style>
""", unsafe_allow_html=True)

# --- GAME LOGIC ---
class Checkers:
    def get_initial_board(self):
        board = np.zeros((BOARD_SIZE, BOARD_SIZE), dtype=np.int8)
        for r in range(3):
            for c in range(BOARD_SIZE):
                if (r + c) % 2 == 1: board[r, c] = -1
        for r in range(5, BOARD_SIZE):
            for c in range(BOARD_SIZE):
                if (r + c) % 2 == 1: board[r, c] = 1
        return board

    def get_valid_moves(self, board, player):
        jumps = self._get_all_jumps(board, player)
        if jumps: return jumps
        moves = []
        for r in range(BOARD_SIZE):
            for c in range(BOARD_SIZE):
                if board[r, c] * player > 0: moves.extend(self._get_simple_moves(board, r, c))
        return moves

    def _get_simple_moves(self, board, r, c):
        moves = []; piece = board[r, c]; player = np.sign(piece)
        directions = [(-1, -1), (-1, 1)] if player == 1 else [(1, -1), (1, 1)]
        if abs(piece) == 2: directions.extend([(1, -1), (1, 1)] if player == 1 else [(-1, -1), (-1, 1)])
        for dr, dc in directions:
            nr, nc = r + dr, c + dc
            if 0 <= nr < BOARD_SIZE and 0 <= nc < BOARD_SIZE and board[nr, nc] == 0: moves.append(((r, c), (nr, nc)))
        return moves

    def _get_all_jumps(self, board, player):
        all_jumps = []
        for r in range(BOARD_SIZE):
            for c in range(BOARD_SIZE):
                if board[r, c] * player > 0:
                    jumps = self._find_jump_sequences(np.copy(board), r, c)
                    if jumps: all_jumps.extend(jumps)
        if not all_jumps: return []
        # Captures are mandatory: only the longest capture sequences are kept.
        max_len = max(len(j) for j in all_jumps)
        return [j for j in all_jumps if len(j) == max_len]

    def _find_jump_sequences(self, board, r, c, path=None):
        if path is None: path = []  # avoid a mutable default argument
        piece = board[r, c]; player = np.sign(piece)
        if piece == 0: return []
        directions = [(-1, -1), (-1, 1), (1, -1), (1, 1)] if abs(piece) == 2 else \
                     [(-1, -1), (-1, 1)] if player == 1 else [(1, -1), (1, 1)]
        found_jumps = []
        for dr, dc in directions:
            mid_r, mid_c = r + dr, c + dc; end_r, end_c = r + 2*dr, c + 2*dc
            if 0 <= end_r < BOARD_SIZE and 0 <= end_c < BOARD_SIZE and \
               board[mid_r, mid_c] * player < 0 and board[end_r, end_c] == 0:
                move = ((r, c), (end_r, end_c))
                new_board = np.copy(board); new_board[end_r, end_c] = piece; new_board[r, c] = 0; new_board[mid_r, mid_c] = 0
                next_jumps = self._find_jump_sequences(new_board, end_r, end_c, path + [move])
                if next_jumps: found_jumps.extend(next_jumps)
                else: found_jumps.append(path + [move])
        return found_jumps

    def apply_move(self, board, move):
        b_ = np.copy(board)
        # The move may arrive as a list (jump chain) or as a nested tuple
        # (a jump chain converted by the MCTS), so we normalise the format here.
        is_jump_chain = False
        if isinstance(move, list):
            is_jump_chain = True
        elif isinstance(move, tuple) and len(move) > 0 and isinstance(move[0], tuple) and len(move[0]) > 0 and isinstance(move[0][0], tuple):
            is_jump_chain = True
        sub_moves = move if is_jump_chain else [move]
        for (r1, c1), (r2, c2) in sub_moves:
            piece = b_[r1, c1]
            if piece == 0: continue
            b_[r2, c2] = piece; b_[r1, c1] = 0
            if abs(r1 - r2) == 2: b_[(r1+r2)//2, (c1+c2)//2] = 0
        r_final, c_final = sub_moves[-1][1]; p_final = b_[r_final, c_final]
        if p_final == 1 and r_final == 0: b_[r_final, c_final] = 2
        if p_final == -1 and r_final == BOARD_SIZE-1: b_[r_final, c_final] = -2
        return b_

    def check_game_over(self, board, player):
        if not self.get_valid_moves(board, player): return -1
        if not np.any(np.sign(board) == -player): return 1
        return None
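
# For illustration only (these coordinates are made up): a simple move is a pair of
# squares such as ((5, 0), (4, 1)), while a capture is a list of such pairs covering
# the whole jump chain, e.g. [((5, 2), (3, 4)), ((3, 4), (1, 6))]. apply_move accepts
# both the list form and the nested-tuple form produced later by the MCTS.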

def state_to_tensor(board, player):
    tensor = np.zeros((5, BOARD_SIZE, BOARD_SIZE), dtype=np.float32)
    tensor[0, board == player] = 1; tensor[1, board == player*2] = 1
    tensor[2, board == -player] = 1; tensor[3, board == -player*2] = 1
    if player == 1: tensor[4, :, :] = 1.0
    return torch.from_numpy(tensor).unsqueeze(0).to(DEVICE)
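
# For reference: the result is a (1, 5, 8, 8) float tensor whose planes hold, in order,
# the side-to-move's men, its kings, the opponent's men, the opponent's kings, and a
# constant plane that is all ones only when the side to move is player 1.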

class PolicyValueNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        num_channels = 64
        self.body = nn.Sequential(nn.Conv2d(5, num_channels, 3, padding=1), nn.BatchNorm2d(num_channels), nn.ReLU(),
                                  nn.Conv2d(num_channels, num_channels, 3, padding=1), nn.BatchNorm2d(num_channels), nn.ReLU(),
                                  nn.Conv2d(num_channels, num_channels, 3, padding=1), nn.BatchNorm2d(num_channels), nn.ReLU())
        self.policy_head = nn.Sequential(nn.Conv2d(num_channels, 4, 1), nn.BatchNorm2d(4), nn.ReLU(), nn.Flatten(),
                                         nn.Linear(4 * BOARD_SIZE * BOARD_SIZE, BOARD_SIZE * BOARD_SIZE))
        self.value_head = nn.Sequential(nn.Conv2d(num_channels, 2, 1), nn.BatchNorm2d(2), nn.ReLU(), nn.Flatten(),
                                        nn.Linear(2 * BOARD_SIZE * BOARD_SIZE, 64), nn.ReLU(),
                                        nn.Linear(64, 1), nn.Tanh())

    def forward(self, x):
        x = self.body(x); return self.policy_head(x), self.value_head(x)

class MCTSNode:
    def __init__(self, parent=None, prior=0.0):
        self.parent = parent; self.prior = prior; self.children = {}; self.visits = 0; self.value_sum = 0.0

    def get_value(self): return self.value_sum / self.visits if self.visits > 0 else 0.0

class MCTS:
    def __init__(self, game, model, sims=100, c_puct=1.5):
        self.game, self.model, self.sims, self.c_puct = game, model, sims, c_puct

    def run(self, board, player):
        root = MCTSNode()
        self._expand_and_evaluate(root, board, player)
        for _ in range(self.sims):
            node, search_board, search_player = root, np.copy(board), player
            search_path = [root]
            while node.children:
                move, node = self._select_child(node)
                search_board = self.game.apply_move(search_board, move); search_player *= -1; search_path.append(node)
            value = self.game.check_game_over(search_board, search_player)
            if value is None and node.visits == 0: value = self._expand_and_evaluate(node, search_board, search_player)
            elif value is None: value = node.get_value()
            for n in reversed(search_path): n.visits += 1; n.value_sum += value; value *= -1
        moves = list(root.children.keys())
        visits = np.array([root.children[m].visits for m in moves])
        return moves, visits / (np.sum(visits) + 1e-9)

    def _select_child(self, node):
        sqrt_total_visits = np.sqrt(node.visits); best_move, max_score = None, -np.inf
        for move, child in node.children.items():
            score = -child.get_value() + self.c_puct * child.prior * sqrt_total_visits / (1 + child.visits)
            if score > max_score: max_score, best_move = score, move
        return best_move, node.children[best_move]
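
    # Selection uses a PUCT-style score: -Q(child) + c_puct * P(child) * sqrt(N(parent)) / (1 + N(child)).
    # Q(child) is stored from the point of view of the player to move at the child (the opponent),
    # hence the leading minus sign; backpropagation in run() flips the sign at each step for the same reason.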

    def _expand_and_evaluate(self, node, board, player):
        valid_moves = self.game.get_valid_moves(board, player)
        if not valid_moves: return -1.0
        with torch.no_grad():
            policy_logits, value_tensor = self.model(state_to_tensor(board, player))
        value = value_tensor.item()
        policy_probs = F.softmax(policy_logits, dim=1).cpu().numpy()[0]
        move_priors = {}; total_prior = 0
        for move in valid_moves:
            if isinstance(move, list): start_pos_tuple = move[0][0]
            else: start_pos_tuple = move[0]
            start_pos_idx = start_pos_tuple[0] * BOARD_SIZE + start_pos_tuple[1]
            prior = policy_probs[start_pos_idx]
            # IMPORTANT: the MCTS converts list moves to tuples here so they can be used as dictionary keys.
            key = tuple(move) if isinstance(move, list) else move
            move_priors[key] = prior; total_prior += prior
        if total_prior > 0:
            for move_key, prior in move_priors.items(): node.children[move_key] = MCTSNode(parent=node, prior=prior / total_prior)
        else:
            for move in valid_moves:
                key = tuple(move) if isinstance(move, list) else move
                node.children[key] = MCTSNode(parent=node, prior=1.0 / len(valid_moves))
        return value
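
# Note on priors: the policy head outputs one logit per board square, so each legal move's
# prior is read from its starting square; moves that share a start square share a prior, and
# the priors of the legal moves are renormalised before the children are created (with a
# uniform fallback if all priors are zero).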

# --- GRAPHICAL INTERFACE ---
@st.cache_resource
def load_model():
    # Cache the loaded network across Streamlit reruns so it is only read from disk once.
    if not os.path.exists(MODEL_PATH):
        return None
    model = PolicyValueNetwork().to(DEVICE)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()
    return model

model = load_model()
if model is None:
    st.error(f"File '{MODEL_PATH}' not found!")
    st.stop()

if "board" not in st.session_state:
    game = Checkers()
    st.session_state.board = game.get_initial_board()
    st.session_state.player = 1
    st.session_state.game_over = False
    st.session_state.message = "Welcome to the Arena! You play WHITE (⚪)."

game = Checkers()
mcts = MCTS(game, model, sims=150)

def format_move_for_human(move):
    """
    Formats a move as human-readable text.
    Handles both lists (jump chains as generated) and nested tuples (jump chains converted by the MCTS).
    """
    is_jump = False
    # 1. A list is always a jump chain.
    if isinstance(move, list):
        is_jump = True
    # 2. A tuple needs inspecting to tell a converted jump chain from a simple move.
    elif isinstance(move, tuple):
        # If the first item is ANOTHER coordinate pair (e.g. ((r,c), (r,c))), it is a converted jump chain.
        # A simple move has an int as the first sub-item: ((2,3), (3,4)) -> move[0] is (2,3), move[0][0] is 2 (an int).
        if len(move) > 0 and isinstance(move[0], tuple) and len(move[0]) > 0 and isinstance(move[0][0], tuple):
            is_jump = True
    if is_jump:
        # Multi-jump or single capture
        path = " -> ".join([f"({r},{c})" for (r, c), _ in move] + [f"({move[-1][1][0]},{move[-1][1][1]})"])
        return f"Jump/Capture: {path}"
    else:
        # Simple move
        (r1, c1), (r2, c2) = move
        return f"Move from ({r1}, {c1}) to ({r2}, {c2})"

def render_board_html(board):
    html = '<div class="board-container">'
    html += '<div class="header-cell"></div>'
    for c in range(8): html += f'<div class="header-cell">{c}</div>'
    for r in range(8):
        html += f'<div class="header-cell">{r}</div>'
        for c in range(8):
            color_class = "black-square" if (r + c) % 2 == 1 else "white-square"
            piece = board[r, c]
            content = ""
            if piece == 1: content = '<span class="piece-white">⚪</span>'
            elif piece == 2: content = '<span class="piece-white king">👑</span>'
            elif piece == -1: content = '<span class="piece-black">⚫</span>'
            elif piece == -2: content = '<span class="piece-black king">👑</span>'
            html += f'<div class="square {color_class}">{content}</div>'
    html += '</div>'
    return html

st.title("♟️ AlphaCheckerZero Arena")

col1, col2 = st.columns([1.5, 1])

with col1:
    st.markdown(render_board_html(st.session_state.board), unsafe_allow_html=True)

with col2:
    st.write("### 🕹️ Control Panel")
    if st.session_state.game_over:
        st.warning(st.session_state.message)
        if st.button("🔄 Play Again", use_container_width=True):
            st.session_state.board = game.get_initial_board()
            st.session_state.player = 1
            st.session_state.game_over = False
            st.session_state.message = "New game started!"
            st.rerun()
    else:
        st.info(st.session_state.message)
        if st.session_state.player == 1:
            valid_moves = game.get_valid_moves(st.session_state.board, 1)
            if not valid_moves:
                st.session_state.game_over = True
                st.session_state.message = "No valid moves available. You lost. 😔"
                st.rerun()
            move_map = {format_move_for_human(m): m for m in valid_moves}
            selected_desc = st.selectbox("Your turn! Choose a move:", list(move_map.keys()))
            if st.button("✅ Confirm Move", type="primary", use_container_width=True):
                move = move_map[selected_desc]
                st.session_state.board = game.apply_move(st.session_state.board, move)
                st.session_state.player = -1
                st.session_state.message = "The AI is calculating..."
                st.rerun()
        else:
            with st.spinner("AlphaCheckerZero is thinking..."):
                time.sleep(0.2)
                valid_moves, policy = mcts.run(np.copy(st.session_state.board), -1)
                if not valid_moves:
                    st.session_state.game_over = True
                    st.session_state.message = "The AI is stuck! YOU WIN! 🎉"
                    st.rerun()
                move = valid_moves[np.argmax(policy)]
                # THIS IS WHERE THE ERROR USED TO OCCUR:
                # the AI returns the move (which may be a jump tuple),
                # and 'format_move_for_human' now knows how to handle it.
                move_text = format_move_for_human(move)
                st.session_state.board = game.apply_move(st.session_state.board, move)
                st.session_state.player = 1
                st.session_state.message = f"AI moved: {move_text}. Your turn!"
                st.rerun()

st.markdown("---")
st.caption("**Legend:** ⚪ Your pieces | ⚫ AI pieces | 👑 King")