| | """ |
| | Chat interface for CosmicFish model downloaded from Hugging Face Hub. |
| | Uses safetensors format only for secure model loading. |
| | """ |
| |
|
| | import os |
| | import sys |
| | import time |
| | import argparse |
| | import torch |
| | import numpy as np |
| | from termcolor import colored |
| | import logging |
| | import readline |
| | import re |
| | import textwrap |
| | import random |
| | from collections import defaultdict |
| | import json |
| |
|
| | |
| | try: |
| | from transformers import GPT2Tokenizer |
| | from huggingface_hub import hf_hub_download, snapshot_download |
| | HF_AVAILABLE = True |
| | except ImportError: |
| | HF_AVAILABLE = False |
| | print("Required libraries not available.") |
| | print("Install with: pip install transformers huggingface-hub") |
| | sys.exit(1) |
| |
|
| | |
| | try: |
| | from safetensors.torch import load_file |
| | SAFETENSORS_AVAILABLE = True |
| | except ImportError: |
| | SAFETENSORS_AVAILABLE = False |
| | print("Safetensors not available. Install with: pip install safetensors") |
| | sys.exit(1) |
| |
|
| | |
# Log to stdout so log lines interleave cleanly with chat output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Hugging Face Hub repository the model is fetched from by default.
DEFAULT_MODEL_REPO = "MistyozAI/CosmicFish-90M"

# System-style preamble prepended to every conversation before any turns.
DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
| |
|
| |
|
class CosmicConfig:
    """Configuration class for CosmicFish.

    Holds the architecture hyperparameters (vocab/context sizes, layer/head
    counts, feature switches for RoPE, SwiGLU, QK-norm and grouped-query
    attention).

    Raises:
        ValueError: if ``n_head`` is not divisible by the effective number of
            query groups. (Was an ``assert`` before; asserts are stripped
            under ``python -O``, so validation now raises explicitly.)
    """

    def __init__(self,
                 vocab_size=50257,
                 block_size=512,
                 n_layer=10,
                 n_head=16,
                 n_embd=640,
                 bias=True,
                 dropout=0.0,
                 n_query_groups=4,
                 eps=1e-6,
                 use_rotary=True,
                 use_swiglu=True,
                 use_qk_norm=False,
                 use_gqa=True):
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.bias = bias
        self.dropout = dropout
        self.eps = eps
        self.use_rotary = use_rotary
        self.use_swiglu = use_swiglu
        self.use_qk_norm = use_qk_norm
        self.use_gqa = use_gqa
        # Without GQA every head is its own group (standard multi-head attention).
        self.n_query_groups = n_query_groups if use_gqa else n_head

        if n_head % self.n_query_groups != 0:
            raise ValueError("n_head must be divisible by n_query_groups")
| |
|
| |
|
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization with a learned per-feature gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        # eps keeps the denominator away from zero for all-zero inputs.
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the RMS of the last dimension, then scale.
        denom = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
        normalized = x / denom
        return self.weight * normalized
| |
|
| |
|
def precompute_freqs_cis(dim, end, theta=10000.0):
    """Precompute complex rotary-embedding phases (cis) for positions 0..end-1.

    Returns a complex tensor of shape (end, dim // 2).
    """
    # Inverse frequencies for each even channel pair.
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    # Outer product gives the rotation angle per (position, channel pair).
    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers e^{i*angle}.
    return torch.polar(torch.ones_like(angles), angles)
| |
|
| |
|
def apply_rotary_emb(xq, xk, freqs_cis):
    """Rotate query/key tensors by the complex phases in freqs_cis.

    xq, xk are (B, heads, T, head_dim); head_dim is viewed as head_dim/2
    complex pairs. Raises ValueError when freqs_cis is shorter than T.
    """
    # Reinterpret the last dimension as complex pairs: (..., d) -> (..., d/2).
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    seq_len = q_complex.size(2)
    if freqs_cis.size(0) < seq_len:
        raise ValueError(f"freqs_cis has only {freqs_cis.size(0)} values but sequence length is {seq_len}")

    # Broadcast the per-position phases over the batch dimension.
    phases = freqs_cis[:seq_len].unsqueeze(0)
    q_rotated = torch.view_as_real(q_complex * phases).flatten(3)
    k_rotated = torch.view_as_real(k_complex * phases).flatten(3)

    # Restore the caller's dtype.
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
| |
|
| |
|
class GroupedQueryAttention(torch.nn.Module):
    """Grouped Query Attention (GQA) implementation"""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        head_dim = config.n_embd // config.n_head
        self.head_dim = head_dim
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.n_query_groups = config.n_query_groups

        # Number of key/value heads.
        # NOTE(review): this computes n_head // n_query_groups rather than
        # using n_query_groups directly; the two coincide for the default
        # 16 heads / 4 groups but differ otherwise — confirm against the
        # training-time definition before changing.
        self.kv_heads = config.n_head // config.n_query_groups if config.use_gqa else config.n_head
        # Fused QKV projection: n_head query heads + kv_heads keys + kv_heads values.
        qkv_proj_size = (config.n_head + 2 * self.kv_heads) * head_dim

        self.c_attn = torch.nn.Linear(config.n_embd, qkv_proj_size, bias=config.bias)
        self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Use the fused kernel when this torch build provides it; otherwise
        # fall back to manual attention with a precomputed causal mask buffer.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                 .view(1, 1, config.block_size, config.block_size))

        # Optional RMSNorm applied to per-head queries and keys.
        self.qk_norm = getattr(config, 'use_qk_norm', False)
        if self.qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
            self.k_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))

    def forward(self, x, freqs_cis=None):
        """Causal self-attention: x is (B, T, C), returns (B, T, C)."""
        B, T, C = x.size()
        qkv = self.c_attn(x)
        head_dim = C // self.n_head

        # Split the fused projection into query/key/value chunks.
        q_size = self.n_head * head_dim
        k_size = self.kv_heads * head_dim
        v_size = self.kv_heads * head_dim

        q, k, v = qkv.split([q_size, k_size, v_size], dim=2)

        # (B, T, heads, head_dim) -> (B, heads, T, head_dim)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
        v = v.view(B, T, self.kv_heads, head_dim).transpose(1, 2)

        # GQA: duplicate each k/v head so every query head has a partner.
        if self.kv_heads < self.n_head:
            repeats = self.n_head // self.kv_heads
            k = k.repeat_interleave(repeats, dim=1)
            v = v.repeat_interleave(repeats, dim=1)

        # Rotary position embedding on queries and keys, when phases are given.
        if freqs_cis is not None:
            q, k = apply_rotary_emb(q, k, freqs_cis)

        # Note: in this implementation QK-norm runs AFTER RoPE.
        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
            )
        else:
            # Manual path: scale by 1/sqrt(head_dim), causal mask, softmax.
            att = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = torch.nn.functional.softmax(att, dim=-1)
            y = att @ v

        # Merge heads back into the model dimension and project out.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y
| |
|
| |
|
class Block(torch.nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
        self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
        self.attn = GroupedQueryAttention(config)

        hidden = 4 * config.n_embd
        if config.use_swiglu:
            # SwiGLU feed-forward: down(silu(up(x)) * gate(x)).
            # Module names (gate/up/down/act) must match checkpoint keys.
            self.mlp = torch.nn.ModuleDict(dict(
                gate=torch.nn.Linear(config.n_embd, hidden, bias=config.bias),
                up=torch.nn.Linear(config.n_embd, hidden, bias=config.bias),
                down=torch.nn.Linear(hidden, config.n_embd, bias=config.bias),
                act=torch.nn.SiLU(),
            ))
            mlp = self.mlp
            self.mlpf = lambda t: mlp.down(mlp.act(mlp.up(t)) * mlp.gate(t))
        else:
            # Classic GELU MLP (GPT-2 style names for checkpoint compatibility).
            self.mlp = torch.nn.ModuleDict(dict(
                c_fc=torch.nn.Linear(config.n_embd, hidden, bias=config.bias),
                c_proj=torch.nn.Linear(hidden, config.n_embd, bias=config.bias),
                act=torch.nn.GELU(),
            ))
            mlp = self.mlp
            self.mlpf = lambda t: mlp.c_proj(mlp.act(mlp.c_fc(t)))

    def forward(self, x, freqs_cis=None):
        x = x + self.attn(self.ln_1(x), freqs_cis)
        return x + self.mlpf(self.ln_2(x))
| |
|
| |
|
class CosmicFish(torch.nn.Module):
    """
    CosmicFish model for inference only.
    Features: Rotary Positional Embeddings, Grouped-Query Attention, SwiGLU, RMSNorm
    """

    def __init__(self, config):
        # config: CosmicConfig describing the architecture.
        super().__init__()
        self.config = config

        # Core transformer: token embedding, stacked blocks, final RMSNorm.
        self.transformer = torch.nn.ModuleDict(dict(
            wte=torch.nn.Embedding(config.vocab_size, config.n_embd),
            h=torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd, eps=config.eps),
        ))

        self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: input embedding and output projection share one matrix.
        self.transformer.wte.weight = self.lm_head.weight

        # Positional information: RoPE phases when use_rotary, otherwise a
        # learned absolute position table (wpe).
        if config.use_rotary:
            head_dim = config.n_embd // config.n_head
            # NOTE: plain attribute (not a registered buffer), so it is moved
            # to the input's device on each forward rather than by model.to().
            self.freqs_cis = precompute_freqs_cis(head_dim, config.block_size)
        else:
            self.freqs_cis = None
            self.transformer.wpe = torch.nn.Embedding(config.block_size, config.n_embd)

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model."""
        # With weight tying wte/lm_head are counted once; only the learned
        # positional table (if present) is excluded when non_embedding=True.
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and hasattr(self.transformer, 'wpe'):
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def forward(self, idx, targets=None):
        """Forward pass through the model.

        idx: (batch, seq) token ids; targets: optional (batch, seq) labels.
        Returns (logits, loss). With no targets, loss is None and logits cover
        only the final position (inference fast path).
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        # Token embeddings.
        tok_emb = self.transformer.wte(idx)

        # Positional treatment depends on the configured scheme.
        if self.config.use_rotary:
            x = tok_emb
            freqs_cis = self.freqs_cis.to(device) if self.freqs_cis is not None else None
        else:
            pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
            pos_emb = self.transformer.wpe(pos)
            x = tok_emb + pos_emb
            freqs_cis = None

        # Run the transformer stack.
        for block in self.transformer.h:
            x = block(x, freqs_cis)

        # Final normalization.
        x = self.transformer.ln_f(x)

        if targets is not None:
            # Full logits + cross-entropy; label -1 positions are ignored.
            logits = self.lm_head(x)
            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # Inference: project only the last time step (keeps a seq dim of 1).
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Generate text by sampling from the model, token by token.
        """
        for _ in range(max_new_tokens):
            # Crop context to the last block_size tokens when necessary.
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

            # Forward and take the final-position logits, temperature-scaled.
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            # Optional top-k filtering: mask everything below the k-th logit.
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Sample the next token from the resulting distribution.
            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # Append and continue.
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
| |
|
| |
|
class RepetitionPenaltyLogitsProcessor:
    """Scale down the logits of tokens already present in the context."""

    def __init__(self, penalty=1.2):
        # penalty > 1.0 discourages repeats; 1.0 is a no-op.
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        """Penalize scores (in place) at every token id listed in input_ids."""
        seen_scores = torch.gather(scores, 1, input_ids)
        # Positive logits are divided and negative ones multiplied, so the
        # penalty always moves the score toward "less likely".
        penalized = torch.where(seen_scores > 0, seen_scores / self.penalty, seen_scores * self.penalty)
        scores.scatter_(1, input_ids, penalized)
        return scores
| |
|
| |
|
class CosmicFishChatSession:
    """Interactive chat session wrapping a CosmicFish model and tokenizer.

    Maintains the rolling conversation history, formats prompts, streams
    generated tokens with a repetition penalty and retry logic, and implements
    the slash-command helpers (/help, /clear, /stats, /save, /load, ...).
    """

    def __init__(self, model, tokenizer, config):
        """Initialize chat session with model and configuration."""
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = next(model.parameters()).device
        self.history = []          # list of ("human" | "assistant", text) tuples
        self.history_tokens = []   # token ids mirroring self.history, in order
        self.max_history_tokens = config.max_history_tokens
        self.prompt_template = config.prompt_template
        self.human_prefix = config.human_prefix
        self.assistant_prefix = config.assistant_prefix
        self.end_of_turn = config.end_of_turn
        self.block_size = config.block_size
        self.debug_mode = config.debug_mode
        self.repetition_penalty = config.repetition_penalty
        self.min_tokens_to_generate = config.min_tokens_to_generate
        self.max_retries = 20  # generation retries before the failure message

        # Canned clarification replies (kept for compatibility with any caller
        # that samples them; not used by the visible code paths).
        self.fallback_responses = [
            "I'd be happy to help with that. Could you provide more details about what specific information you're looking for?",
            "That's a topic I can provide information about. What specific aspects would you like to know?",
            "I understand your question. I can share factual information on this topic if you could specify what aspects you're interested in.",
            "I can help with your question. To give you the most relevant information, could you clarify what specific details you're looking for?",
            "I'd be glad to address your question. To provide the most helpful response, could you specify what particular aspects of this topic interest you?"
        ]

        self.generation_failure_message = "I'm sorry, but I'm having difficulty generating a response to that prompt. Could you try rephrasing your question or asking something else?"

        # Cumulative token accounting for /stats.
        self.total_prompt_tokens = 0
        self.total_generated_tokens = 0

        # Any of these appearing in generated text ends the assistant's turn.
        self.end_markers = [
            f"{self.human_prefix}",
            "Human:",
            "\nHuman:",
            "\nH:",
            "H:",
            "<|endoftext|>",
            "Below is a conversation",
            "\nA:",
            "A:",
            "</s>",
            "User:",
            "\nUser:"
        ]

        if config.display_welcome:
            self._print_welcome_message()

    def _print_welcome_message(self):
        """Print the banner with model size, usage notes and the command list."""
        # Typo fixed: "CosmicFIsh" -> "CosmicFish".
        welcome_text = f"""
    {'=' * 80}
    Welcome to CosmicFish chat interface

    This is a {self.model.get_num_params() / 1e6:.1f}M parameter model.
    CosmicFish is an efficient LLM with an advanced architecture.

    Type your prompts and CosmicFish will respond.

    Special commands:
    - /help: Show this help message
    - /clear: Clear the conversation history
    - /exit or /quit: Exit the chat
    - /stats: Show token usage statistics
    - /save [filename]: Save the conversation
    - /load [filename]: Load a conversation
    - /temp [value]: Set temperature (between 0.1 and 2.0)
    - /penalty [value]: Set repetition penalty (1.0-2.0)
    - /debug: Toggle debug mode

    Note: CosmicFish may generate incorrect or fictional responses. Verify facts if needed.

    Visit https://cosmicfish.ai for more info

    Developed by Mistyoz AI (https://www.mistyoz.com)
    {'=' * 80}
    """
        print(colored(welcome_text, 'cyan'))

    def _format_prompt(self, user_input):
        """Format the complete prompt: template + history + current input."""
        formatted_prompt = self.prompt_template

        # Replay the stored conversation turns.
        for entry in self.history:
            role, text = entry
            if role == "human":
                formatted_prompt += f"{self.human_prefix}{text}{self.end_of_turn}"
            else:
                formatted_prompt += f"{self.assistant_prefix}{text}{self.end_of_turn}"

        # Current user turn, ending with the assistant prefix the model completes.
        formatted_prompt += f"{self.human_prefix}{user_input}{self.end_of_turn}{self.assistant_prefix}"

        return formatted_prompt

    def _tokenize(self, text):
        """Tokenize text and return token IDs."""
        return self.tokenizer.encode(text)

    def _update_history(self, user_input, response):
        """Append one exchange to history and its token mirror, then trim."""
        self.history.append(("human", user_input))
        self.history.append(("assistant", response))

        user_tokens = self._tokenize(f"{self.human_prefix}{user_input}{self.end_of_turn}")
        response_tokens = self._tokenize(f"{self.assistant_prefix}{response}{self.end_of_turn}")

        self.history_tokens.extend(user_tokens)
        self.history_tokens.extend(response_tokens)

        # Accounting for /stats.
        self.total_prompt_tokens += len(user_tokens)
        self.total_generated_tokens += len(response_tokens)

        self._trim_history_if_needed()

    def _trim_history_if_needed(self):
        """Drop oldest exchanges until history fits within max_history_tokens.

        Fix: the original removed the oldest pair first and then measured the
        NEW first pair, stripping the wrong token prefix (and could raise
        IndexError once the last pair was removed). We now measure the pair
        being evicted BEFORE removing it.
        """
        while len(self.history_tokens) > self.max_history_tokens and len(self.history) >= 2:
            # Assumes history alternates (human, assistant) as _update_history writes it.
            user_turn = self.history[0][1]
            assistant_turn = self.history[1][1]
            user_tokens = len(self._tokenize(f"{self.human_prefix}{user_turn}{self.end_of_turn}"))
            assistant_tokens = len(self._tokenize(f"{self.assistant_prefix}{assistant_turn}{self.end_of_turn}"))

            # Evict the pair and the matching token prefix together.
            self.history = self.history[2:]
            self.history_tokens = self.history_tokens[user_tokens + assistant_tokens:]

    def _should_stop_generation(self, text):
        """Check if generation should stop based on end markers."""
        for marker in self.end_markers:
            if marker in text:
                return True
        return False

    def _clean_token_text(self, text):
        """Replace common mojibake/replacement sequences with an apostrophe.

        Deduplicated: '\\ufffd' and '\\uFFFD' are the same character as the
        literal replacement char already handled below, and one em-dash
        mojibake sequence was listed twice.
        """
        text = text.replace('��', "'")
        text = text.replace('�', "'")
        text = text.replace('’', "'")
        text = text.replace('â€Å"', "'")
        text = text.replace('â€"', "'")
        return text

    def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
        """Generate with repetition penalty.

        live=True  -> returns a generator yielding (token_text, buffer, should_stop).
        live=False -> returns the full tensor of token ids (or None if nothing
                      was generated).

        Fix: the original was a single generator function whose non-live
        ``return generated`` was swallowed by StopIteration, so callers could
        never obtain the tensor. The two modes are now separate code paths
        behind this unchanged entry point.
        """
        if live:
            return self._generate_live_stream(input_ids, max_new_tokens, temperature, top_k, penalty)
        return self._generate_batch(input_ids, max_new_tokens, temperature, top_k, penalty)

    def _sample_next_token(self, generated, temperature, top_k, rep_processor, penalty):
        """One sampling step: crop context, forward, penalize, top-k, sample."""
        # Crop to the model's context window.
        if generated.size(1) > self.block_size:
            context = generated[:, -self.block_size:]
        else:
            context = generated

        with torch.no_grad():
            logits, _ = self.model(context)

        # Final-position logits, temperature-scaled.
        next_token_logits = logits[:, -1, :] / temperature

        # Repetition penalty over tokens already in the context.
        if penalty > 1.0:
            next_token_logits = rep_processor(context, next_token_logits)

        # Top-k filtering.
        if top_k is not None:
            indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
            next_token_logits[indices_to_remove] = float('-inf')

        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def _generate_live_stream(self, input_ids, max_new_tokens, temperature, top_k, penalty):
        """Token-by-token generator yielding (token_text, buffer, should_stop)."""
        self.model.eval()
        generated = input_ids.clone()
        live_buffer = ""
        rep_processor = RepetitionPenaltyLogitsProcessor(penalty=penalty)
        tokens_generated = 0
        min_tokens = self.min_tokens_to_generate
        # GPT-2's <|endoftext|> id as fallback when the tokenizer lacks one.
        eot_token_id = self.tokenizer.eos_token_id if hasattr(self.tokenizer, 'eos_token_id') else 50256

        for _ in range(max_new_tokens):
            next_token = self._sample_next_token(generated, temperature, top_k, rep_processor, penalty)

            # Hard stop on the end-of-text token itself.
            if next_token.item() == eot_token_id:
                yield "", live_buffer, True
                break

            generated = torch.cat((generated, next_token), dim=1)
            tokens_generated += 1

            next_token_text = self._clean_token_text(self.tokenizer.decode([next_token.item()]))
            live_buffer += next_token_text

            # Stop (and truncate) if the literal marker leaked into the text.
            eot_marker_pos = live_buffer.find("<|endoftext|>")
            if eot_marker_pos != -1:
                live_buffer = live_buffer[:eot_marker_pos]
                yield "", live_buffer, True
                break

            # End-marker stop only after the minimum token budget is met.
            should_stop = tokens_generated >= min_tokens and self._should_stop_generation(live_buffer)
            yield next_token_text, live_buffer, should_stop

            if should_stop:
                break

    def _generate_batch(self, input_ids, max_new_tokens, temperature, top_k, penalty):
        """Non-streaming generation; returns the token tensor or None."""
        self.model.eval()
        generated = input_ids.clone()
        rep_processor = RepetitionPenaltyLogitsProcessor(penalty=penalty)
        tokens_generated = 0
        min_tokens = self.min_tokens_to_generate
        eot_token_id = self.tokenizer.eos_token_id if hasattr(self.tokenizer, 'eos_token_id') else 50256

        for _ in range(max_new_tokens):
            next_token = self._sample_next_token(generated, temperature, top_k, rep_processor, penalty)

            if next_token.item() == eot_token_id:
                break

            generated = torch.cat((generated, next_token), dim=1)
            tokens_generated += 1

            # Check the tail of the decoded text for end markers.
            if tokens_generated >= min_tokens:
                recent_text = self.tokenizer.decode(generated[0, -20:].tolist())
                if self._should_stop_generation(recent_text):
                    break

        if tokens_generated == 0:
            if self.debug_mode:
                print(colored("\n[No tokens generated in this attempt]", "red"))
            return None

        return generated

    def generate_response(self, user_input):
        """Generate a response to the user input (returns a live generator)."""
        prompt = self._format_prompt(user_input)

        input_ids = torch.tensor(self._tokenize(prompt), dtype=torch.long).unsqueeze(0).to(self.device)

        # If the prompt overflows the context window, keep the instruction
        # preamble and the most recent conversation tail.
        if input_ids.size(1) > self.block_size:
            instruction_tokens = self._tokenize(self.prompt_template)
            keep_from_beginning = len(instruction_tokens)
            keep_from_end = self.block_size - keep_from_beginning

            if keep_from_end < 0:
                # Instruction alone overflows: hard truncate.
                input_ids = input_ids[:, :self.block_size]
            else:
                input_ids = torch.cat([
                    input_ids[:, :keep_from_beginning],
                    input_ids[:, -(keep_from_end):]
                ], dim=1)

        start_time = time.time()

        return self._generate_live_response(input_ids, user_input, start_time)

    def _generate_live_response(self, input_ids, user_input, start_time):
        """Generate response with live token-by-token output and retries.

        Retries alternate between raising and lowering the temperature.
        Fix: the success path now explicitly breaks out of the retry loop.
        """
        live_text = ""
        tokens_generated = 0
        retry_count = 0

        while retry_count <= self.max_retries:
            if retry_count > 0:
                # Even retries nudge the temperature up, odd retries down.
                if retry_count % 2 == 0:
                    temp_adjustment = min(0.2 * (retry_count // 2), 0.8)
                    current_temp = min(self.config.temperature + temp_adjustment, 1.8)
                else:
                    temp_adjustment = min(0.2 * ((retry_count + 1) // 2), 0.4)
                    current_temp = max(self.config.temperature - temp_adjustment, 0.2)

                if self.debug_mode:
                    print(colored(f"\n[Live retry {retry_count}: Using temperature {current_temp:.2f}]", "yellow"))
            else:
                current_temp = self.config.temperature

            # Reset per-attempt state.
            live_text = ""
            tokens_generated = 0
            generation_failed = False

            try:
                for token_text, live_buffer, should_stop in self.generate_with_repetition_penalty(
                    input_ids,
                    max_new_tokens=self.config.max_new_tokens,
                    temperature=current_temp,
                    top_k=self.config.top_k,
                    penalty=self.repetition_penalty,
                    live=True
                ):
                    if should_stop:
                        # Use the (possibly truncated) buffer as the final text.
                        live_text = live_buffer
                        break

                    if token_text:
                        live_text += token_text
                        tokens_generated += 1
                        yield token_text, live_text, False

                if not live_text or len(live_text.strip()) < 10:
                    if self.debug_mode:
                        print(colored("\n[Live generation produced empty or too short response, retrying]", "yellow"))
                    generation_failed = True
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        # Clear the partial line before retrying.
                        print("\r" + " " * 80 + "\r", end="")
                        continue
                    break
                # Success: leave the retry loop.
                break

            except Exception as e:
                if self.debug_mode:
                    print(colored(f"\n[Live generation error: {str(e)}, retrying]", "red"))
                generation_failed = True
                retry_count += 1

        # Fall back to a canned apology if nothing usable was produced.
        if generation_failed or not live_text or len(live_text.strip()) < 10:
            live_text = self.generation_failure_message
            if self.debug_mode:
                print(colored(f"\n[Returning failure message after {retry_count} live retries]", "red"))

        time_taken = time.time() - start_time
        tokens_per_second = tokens_generated / time_taken if time_taken > 0 else 0

        self._update_history(user_input, live_text)

        logger.debug(f"Generated {tokens_generated} tokens in {time_taken:.2f}s ({tokens_per_second:.2f} tokens/s)")

        # Final marker so the caller knows the turn is complete.
        yield "", live_text, True

    def execute_command(self, command):
        """Execute a special command prefixed with /.

        Returns False only for /exit and /quit (caller should stop the REPL).
        """
        command = command.strip()

        if command == '/help':
            self._print_welcome_message()
            return True

        elif command == '/clear':
            self.history = []
            self.history_tokens = []
            print(colored("Conversation history cleared.", 'yellow'))
            return True

        elif command in ['/exit', '/quit']:
            print(colored("Goodbye!", 'cyan'))
            return False

        elif command == '/stats':
            prompt_tokens = self.total_prompt_tokens
            generated_tokens = self.total_generated_tokens
            total_tokens = prompt_tokens + generated_tokens

            stats = f"""
    Token usage statistics:
    - Prompt tokens: {prompt_tokens}
    - Generated tokens: {generated_tokens}
    - Total tokens: {total_tokens}
    - Current history length: {len(self.history_tokens)} tokens
    - Current repetition penalty: {self.repetition_penalty}
    - Current temperature: {self.config.temperature}
    - Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
    - Source: {DEFAULT_MODEL_REPO}
    - Format: Safetensors (secure)
    """
            print(colored(stats, 'yellow'))
            return True

        elif command == '/debug':
            self.debug_mode = not self.debug_mode
            self.config.debug_mode = self.debug_mode
            mode = "enabled" if self.debug_mode else "disabled"
            print(colored(f"Debug mode {mode}", 'yellow'))
            return True

        elif command.startswith('/penalty '):
            try:
                penalty = float(command[9:].strip())
                if 1.0 <= penalty <= 2.0:
                    self.repetition_penalty = penalty
                    print(colored(f"Repetition penalty set to {penalty}", 'yellow'))
                else:
                    print(colored("Repetition penalty should be between 1.0 and 2.0", 'red'))
            except ValueError:
                print(colored("Invalid repetition penalty value. Please use a number between 1.0 and 2.0", 'red'))
            return True

        elif command.startswith('/temp '):
            try:
                temp = float(command[6:].strip())
                if 0.1 <= temp <= 2.0:
                    self.config.temperature = temp
                    print(colored(f"Temperature set to {temp}", 'yellow'))
                else:
                    print(colored("Temperature should be between 0.1 and 2.0", 'red'))
            except ValueError:
                print(colored("Invalid temperature value. Please use a number between 0.1 and 2.0", 'red'))
            return True

        # Also accept bare '/save' so the usage hint is shown instead of
        # "Unknown command" (help text advertises /save [filename]).
        elif command == '/save' or command.startswith('/save '):
            filename = command[5:].strip()
            if not filename:
                print(colored("Please specify a filename: /save <filename>", 'red'))
                return True

            try:
                os.makedirs('conversations', exist_ok=True)

                if not filename.endswith('.txt'):
                    filename += '.txt'

                filepath = os.path.join('conversations', filename)

                with open(filepath, 'w', encoding='utf-8') as f:
                    for entry in self.history:
                        role, text = entry
                        prefix = self.human_prefix if role == "human" else self.assistant_prefix
                        f.write(f"{prefix}{text}{self.end_of_turn}")

                print(colored(f"Conversation saved to {filepath}", 'green'))

            except Exception as e:
                print(colored(f"Error saving conversation: {str(e)}", 'red'))

            return True

        # Bare '/load' handled the same way as bare '/save'.
        elif command == '/load' or command.startswith('/load '):
            filename = command[5:].strip()
            if not filename:
                print(colored("Please specify a filename: /load <filename>", 'red'))
                return True

            try:
                if not filename.endswith('.txt'):
                    filename += '.txt'

                filepath = os.path.join('conversations', filename)

                if not os.path.exists(filepath):
                    print(colored(f"File not found: {filepath}", 'red'))
                    return True

                with open(filepath, 'r', encoding='utf-8') as f:
                    content = f.read()

                # Rebuild the history from scratch.
                self.history = []
                self.history_tokens = []

                # Split on the turn delimiter and classify each turn by prefix.
                turns = content.split(self.end_of_turn)
                for turn in turns:
                    turn = turn.strip()
                    if not turn:
                        continue

                    if turn.startswith(self.human_prefix):
                        text = turn[len(self.human_prefix):].strip()
                        self.history.append(("human", text))
                    elif turn.startswith(self.assistant_prefix):
                        text = turn[len(self.assistant_prefix):].strip()
                        self.history.append(("assistant", text))

                # Recompute the token mirror of the loaded history.
                self.history_tokens = []
                for entry in self.history:
                    role, text = entry
                    if role == "human":
                        self.history_tokens.extend(self._tokenize(f"{self.human_prefix}{text}{self.end_of_turn}"))
                    else:
                        self.history_tokens.extend(self._tokenize(f"{self.assistant_prefix}{text}{self.end_of_turn}"))

                print(colored(f"Loaded conversation from {filepath} ({len(self.history) // 2} turns)", 'green'))

                # Replay the conversation to the terminal, wrapped at 100 cols.
                for i in range(0, len(self.history), 2):
                    if i < len(self.history):
                        user_text = self.history[i][1]
                        print(colored(f"\nYou: {user_text}", 'green'))

                    if i + 1 < len(self.history):
                        assistant_text = self.history[i + 1][1]
                        print(colored("CosmicFish: ", 'blue'), end="")
                        for line in assistant_text.split('\n'):
                            wrapped_lines = textwrap.wrap(line, width=100) if line.strip() else ['']
                            for wrapped_line in wrapped_lines:
                                print(wrapped_line)

            except Exception as e:
                print(colored(f"Error loading conversation: {str(e)}", 'red'))

            return True

        else:
            print(colored(f"Unknown command: {command}. Type /help for available commands.", 'red'))
            return True
| |
|
| |
|
def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
    """Download and load CosmicFish model from Hugging Face Hub (safetensors only)"""
    print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan"))

    try:
        # Fetch the full repo snapshot (or reuse the local HF cache).
        print("Downloading model files...")
        cache_dir = snapshot_download(repo_id=model_repo, cache_dir=None)
        print(f"Model cached at: {cache_dir}")

        # Architecture hyperparameters ship as config.json in the repo.
        config_path = os.path.join(cache_dir, "config.json")
        with open(config_path, "r") as f:
            config_dict = json.load(f)

        # Required keys raise KeyError (caught below); optional keys default.
        config = CosmicConfig(
            vocab_size=config_dict["vocab_size"],
            block_size=config_dict["block_size"],
            n_layer=config_dict["n_layer"],
            n_head=config_dict["n_head"],
            n_embd=config_dict["n_embd"],
            bias=config_dict["bias"],
            dropout=0.0,
            eps=config_dict.get("eps", 1e-6),
            use_rotary=config_dict["use_rotary"],
            use_swiglu=config_dict["use_swiglu"],
            use_gqa=config_dict["use_gqa"],
            n_query_groups=config_dict["n_query_groups"],
            use_qk_norm=config_dict.get("use_qk_norm", False)
        )

        print("Creating model...")
        model = CosmicFish(config)

        # Weights are only accepted in safetensors form (no pickle loading).
        print("Loading weights from safetensors...")
        safetensors_path = os.path.join(cache_dir, "model.safetensors")

        if not os.path.exists(safetensors_path):
            raise FileNotFoundError(f"model.safetensors not found in {cache_dir}. This model requires safetensors format.")

        state_dict = load_file(safetensors_path)

        # Weight tying: checkpoints may omit lm_head.weight; restore from wte.
        if 'lm_head.weight' not in state_dict and 'transformer.wte.weight' in state_dict:
            state_dict['lm_head.weight'] = state_dict['transformer.wte.weight']

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()

        print(f"Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
        print(f"Device: {device}")
        return model, config

    except Exception as e:
        # CLI behavior: any failure while fetching/loading is fatal.
        print(colored(f"Error downloading/loading model: {str(e)}", "red"))
        print(colored("Make sure you have internet connection and the model repo exists", "yellow"))
        sys.exit(1)
| |
|
| |
|
def load_tokenizer():
    """Load the GPT-2 BPE tokenizer that CosmicFish shares its vocab with."""
    return GPT2Tokenizer.from_pretrained("gpt2")
| |
|
| |
|
def main():
    """Entry point: parse CLI options, load model + tokenizer, run the chat loop.

    Flow:
      1. Parse command-line arguments (model repo, device, sampling knobs,
         conversation formatting, UI flags).
      2. Fall back to CPU if ``--device cuda`` was requested but CUDA is absent.
      3. Download/load the model and tokenizer, build a ``ChatConfig``, and
         start a ``CosmicFishChatSession``.
      4. Loop: read user input, dispatch ``/commands``, stream generated
         tokens to stdout until the session ends.

    Exits with status 1 if model/tokenizer setup fails.
    """
    parser = argparse.ArgumentParser(description="Chat with CosmicFish")

    # --- Model / device selection ---
    parser.add_argument("--model_repo", type=str, default=DEFAULT_MODEL_REPO,
                        help=f"Hugging Face model repository (default: {DEFAULT_MODEL_REPO})")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use (cuda or cpu)")

    # --- Sampling parameters ---
    # BUGFIX: the help text previously hard-coded "(default: 0.7)" while the
    # actual default was 0.5.  Using argparse's %(default)s interpolation keeps
    # the help string in sync with the real default forever.
    parser.add_argument("--temperature", type=float, default=0.5,
                        help="Temperature for sampling (default: %(default)s)")
    parser.add_argument("--max_tokens", type=int, default=512,
                        help="Maximum number of tokens to generate per response")
    parser.add_argument("--min_tokens", type=int, default=10,
                        help="Minimum number of tokens to generate per response")
    parser.add_argument("--top_k", type=int, default=40,
                        help="Top-k sampling (0 to disable)")
    parser.add_argument("--repetition_penalty", type=float, default=1.2,
                        help="Repetition penalty (1.0 = no penalty, 1.2 = mild, 1.5 = moderate)")

    # --- Conversation formatting ---
    parser.add_argument("--human_prefix", type=str, default="Human: ",
                        help="Prefix for human messages")
    parser.add_argument("--assistant_prefix", type=str, default="Assistant: ",
                        help="Prefix for assistant messages")
    parser.add_argument("--end_of_turn", type=str, default="\n\n",
                        help="Delimiter between conversation turns")
    parser.add_argument("--instruction", type=str,
                        default=DEFAULT_PROMPT_TEMPLATE,
                        help="Instruction prompt to prepend to the conversation")
    parser.add_argument("--max_history", type=int, default=512,
                        help="Maximum number of tokens to keep in history")

    # --- UI flags ---
    parser.add_argument("--no_welcome", action="store_true",
                        help="Don't display the welcome message")
    parser.add_argument("--debug", action="store_true",
                        help="Enable debug mode")

    args = parser.parse_args()

    # Graceful fallback: honor an explicit --device cuda only if CUDA exists.
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print(colored("CUDA is not available, falling back to CPU", "yellow"))
        device = "cpu"

    try:
        # Download weights/config from the hub and materialize the model.
        model, model_config = download_cosmicfish_from_hub(args.model_repo, device)

        tokenizer = load_tokenizer()

        # Thin adapter bundling CLI args + model block size into the single
        # config object the chat session expects.
        class ChatConfig:
            def __init__(self, args, block_size):
                self.device = device
                self.temperature = args.temperature
                self.max_new_tokens = args.max_tokens
                self.min_tokens_to_generate = args.min_tokens
                self.top_k = args.top_k
                self.human_prefix = args.human_prefix
                self.assistant_prefix = args.assistant_prefix
                self.end_of_turn = args.end_of_turn
                self.prompt_template = args.instruction
                self.max_history_tokens = args.max_history
                self.display_welcome = not args.no_welcome
                self.block_size = block_size
                self.debug_mode = args.debug
                self.repetition_penalty = args.repetition_penalty

        config = ChatConfig(args, model_config.block_size)

        chat = CosmicFishChatSession(model, tokenizer, config)

        print(colored("\nCosmicFish initialized from Hugging Face! Type your message (or /help for commands).\n", 'cyan'))

        # Main REPL loop: read input, dispatch commands, stream responses.
        while True:
            try:
                user_input = input(colored("You: ", 'green'))

                # Slash commands are handled by the session; a falsy return
                # means the user asked to quit.
                if user_input.startswith('/'):
                    if not chat.execute_command(user_input):
                        break
                    continue

                # Ignore empty/whitespace-only input.
                if not user_input.strip():
                    continue

                # live_buffer accumulates streamed tokens so we know whether
                # anything was already printed when the final response arrives.
                live_buffer = ""
                final_response = None

                response_generator = chat.generate_response(user_input)

                try:
                    print(colored("CosmicFish: ", 'blue'), end="")
                    sys.stdout.flush()

                    # Stream tokens as they are generated.  Ordering matters:
                    # the is_done sentinel may arrive with a full final text.
                    for token, live_text, is_done in response_generator:
                        if is_done:
                            final_response = live_text
                            # Nothing streamed yet -> print the whole response.
                            if not live_buffer:
                                print(final_response, end="")
                            break
                        if token:
                            # Stop streaming at the end-of-text marker, after
                            # emitting any text that precedes it.
                            if "<|endoftext|>" in token:
                                token = token.replace("<|endoftext|>", "")
                                if token:
                                    print(token, end="", flush=True)
                                break

                            print(token, end="", flush=True)
                            live_buffer += token

                except KeyboardInterrupt:
                    # Ctrl-C mid-generation: stop cleanly, keep the session.
                    print("\n[Generation interrupted]")
                    final_response = "I was going to respond, but I'll stop here since you interrupted."

                # Terminate the streamed line.
                print()

            except KeyboardInterrupt:
                # Ctrl-C at the prompt: don't exit, just inform the user.
                print("\n\nKeyboard interrupt detected. Type /exit to quit or continue chatting.")

            except Exception as e:
                # Per-turn errors are reported but never kill the loop.
                print(colored(f"\nError: {str(e)}", 'red'))
                logger.error(f"Error in chat loop: {str(e)}", exc_info=True)

    except Exception as e:
        # Setup failures (download, load, tokenizer) are fatal.
        print(colored(f"Error setting up chat: {str(e)}", 'red'))
        logger.error(f"Error setting up chat: {str(e)}", exc_info=True)
        sys.exit(1)
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: any uncaught error is logged with its traceback
    # and the process exits non-zero.
    try:
        main()
    except Exception as e:
        # logger.exception(msg) is equivalent to logger.error(msg, exc_info=True)
        # when called from inside an except block.
        logger.exception(f"Fatal error: {str(e)}")
        sys.exit(1)
| |
|