diff --git a/example-fineweb.ipynb b/example-fineweb.ipynb index b3834cb..a12922e 100644 --- a/example-fineweb.ipynb +++ b/example-fineweb.ipynb @@ -20,7 +20,6 @@ "outputs": [], "source": [ "from tqdm import tqdm\n", - "import json\n", "\n", "import torch\n", "import math\n", @@ -33,6 +32,7 @@ "from model.KANaMoEv1 import KANaMoEv1\n", "\n", "from utils import load_model, quick_inference\n", + "from model.handler import save_pretrained\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] @@ -176,6 +176,14 @@ " device=device\n", ")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32cd9a68", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/model/handler.py b/model/handler.py index c659c62..74dcb8a 100644 --- a/model/handler.py +++ b/model/handler.py @@ -83,4 +83,31 @@ def save_pretrained(path_to_save: str, model: nn.Module): print(f"[INFO] Model and configuration saved successfully to {path_to_save}") - return path_to_save \ No newline at end of file + return path_to_save + + +def quick_inference(model: torch.nn.Module, tokens: torch.Tensor, max_new_tokens: int, tokenizer): + model.eval() # Set model to evaluation mode + with torch.no_grad(): # Disable gradient calculation for inference + for _ in range(max_new_tokens): + # Take the last 'max_seq_len' tokens as input to the model + tokens_conditioned = tokens[:, -model.args.max_seq_len:] + + # Get the model's predictions (logits) + logits, _ = model(tokens_conditioned) + + # Apply softmax to the last token's logits to get probabilities + probabilities = torch.softmax(logits[:, -1, :], dim=-1) + + # Sample the next token from the probability distribution + next_token = torch.multinomial(probabilities, num_samples=1) + + # Append the predicted token to the input sequence + tokens = torch.cat((tokens, next_token), dim=1) + + # Decode and print the token (convert from token ID to string) + decoded_token = tokenizer.decode(next_token.squeeze(dim=1).tolist(), skip_special_tokens=True) + print(decoded_token, end="", flush=True) + + # Return the final generated sequence (both as tokens and decoded text) + return tokens, tokenizer.decode(tokens.squeeze(dim=0).tolist(), skip_special_tokens=True) \ No newline at end of file diff --git a/utils.py b/utils.py index f4c55fe..c22abb9 100644 --- a/utils.py +++ b/utils.py @@ -3,14 +3,6 @@ import torch - -def load_model(model, file_name="trained_KANama_model.pth"): - model.load_state_dict(torch.load(file_name)) - return model - -def save_model(model, file_name="trained_KANama_model.pth"): - torch.save(model.state_dict(), file_name) - def createLossPlot(steps, losses, title="training"): plt.plot(steps, losses, linewidth=1) plt.xlabel("steps") @@ -90,31 +82,4 @@ def visualize_KANama(file_path, combined_file_path, individual_folder_path): plt.tight_layout() plt.savefig(combined_file_path) - plt.show() - - -def quick_inference(model: torch.nn.Module, tokens: torch.Tensor, max_new_tokens: int, tokenizer): - model.eval() # Set model to evaluation mode - with torch.no_grad(): # Disable gradient calculation for inference - for _ in range(max_new_tokens): - # Take the last 'max_seq_len' tokens as input to the model - tokens_conditioned = tokens[:, -model.args.max_seq_len:] - - # Get the model's predictions (logits) - logits, _ = model(tokens_conditioned) - - # Apply softmax to the last token's logits to get probabilities - probabilities = torch.softmax(logits[:, -1, :], dim=-1) - - # Sample the next token from the probability distribution - next_token = torch.multinomial(probabilities, num_samples=1) - - # Append the predicted token to the input sequence - tokens = torch.cat((tokens, next_token), dim=1) - - # Decode and print the token (convert from token ID to string) - decoded_token = tokenizer.decode(next_token.squeeze(dim=1).tolist(), skip_special_tokens=True) - print(decoded_token, end="", flush=True) - - # Return the final generated sequence (both as tokens and decoded text) - return tokens, tokenizer.decode(tokens.squeeze(dim=0).tolist(), skip_special_tokens=True) \ No newline at end of file + plt.show() \ No newline at end of file