import time

import torch
from safetensors.torch import load_file

from Llama3Model import Llama3Model
from cfg import LLAMA32_CONFIG, LLAMA_SIZE_STR
from tools import model_memory_size, generate, text_to_token_ids, token_ids_to_text
from huggingface import chat_tokenizer, tokenizer

# Instantiate the model and move it to the best available device
model = Llama3Model(LLAMA32_CONFIG)

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device)

# Load the pretrained weights and switch to inference mode
weights = load_file("llama32_weights.safetensors")
model.load_state_dict(weights)
model.eval()
print("Weights loaded successfully!")

start = time.time()

PROMPT = "What do llamas eat?"

# top_k=1 combined with temperature=0.0 makes decoding greedy (deterministic)
token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),
    max_new_tokens=150,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=1,
    temperature=0.0
)

print(f"Time: {time.time() - start:.2f} sec")

if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

output_text = token_ids_to_text(token_ids, tokenizer)


def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    # Find the first occurrence of the assistant header
    index = text.find(header_end)

    if index != -1:
        # Return the substring after the header; strip() removes
        # leading/trailing whitespace
        return text[index + len(header_end):].strip()
    else:
        # If the header is not found, return the original text
        return text


print("\n\nOutput text:\n\n", clean_text(output_text))
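
# ----------------------------------------------------------------------
# For reference: with top_k=1 and temperature=0.0, the generate() call
# above reduces to greedy decoding. The sketch below shows one way such a
# loop can be implemented; it is an illustrative assumption, not the
# implementation imported from `tools`.
# ----------------------------------------------------------------------

def greedy_generate_sketch(model, idx, max_new_tokens, context_size):
    # idx is a (batch, seq_len) tensor of token ids
    for _ in range(max_new_tokens):
        # Crop the running sequence to the model's supported context window
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        # Take the logits at the last position and pick the argmax token
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        # Append the new token and continue
        idx = torch.cat((idx, next_id), dim=1)
    return idx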
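
# ----------------------------------------------------------------------
# Likewise, the tokenizer round-trip helpers imported from `tools` can be
# as simple as the following sketch, assuming a tokenizer object that
# exposes encode()/decode() methods. The names and bodies here are
# hypothetical, for illustration only.
# ----------------------------------------------------------------------

def text_to_token_ids_sketch(text, tokenizer):
    # Encode the prompt and add the batch dimension the model expects
    return torch.tensor(tokenizer.encode(text)).unsqueeze(0)


def token_ids_to_text_sketch(token_ids, tokenizer):
    # Drop the batch dimension and decode the ids back into a string
    return tokenizer.decode(token_ids.squeeze(0).tolist())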