import asyncio
import logging
from typing import Any, Dict, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class Charm15Chatbot:
    def __init__(
        self,
        model_path: str,
        device: Optional[str] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the chatbot.

        Args:
            model_path (str): Path or name of the pre-trained model.
            device (str, optional): Device to run the model on (e.g., "cuda" or "cpu").
                Defaults to "cuda" if available, otherwise "cpu".
            tokenizer_kwargs (dict, optional): Additional arguments for the tokenizer.
            model_kwargs (dict, optional): Additional arguments for the model.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer_kwargs = tokenizer_kwargs or {}
        self.model_kwargs = model_kwargs or {}

        logger.info(f"Loading model and tokenizer from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, **self.tokenizer_kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, **self.model_kwargs).to(self.device)
        self.model.eval()
        logger.info("Model and tokenizer loaded successfully.")

    def generate_response(
        self,
        input_text: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
        repetition_penalty: float = 1.0,
        **kwargs,
    ) -> str:
        """
        Generate a response to the input text.

        Args:
            input_text (str): The input prompt.
            max_new_tokens (int): Maximum number of new tokens to generate.
            temperature (float): Sampling temperature (higher = more random).
            top_p (float): Top-p (nucleus) sampling threshold.
            top_k (int, optional): Top-k sampling cutoff.
            repetition_penalty (float): Penalty for repeating tokens (1.0 = no penalty).
            **kwargs: Additional arguments forwarded to model.generate().

        Returns:
            str: The generated response, with the prompt tokens stripped.
        """
        try:
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                truncation=True,
                max_length=1024,
            ).to(self.device)

            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,  # required for temperature/top_p/top_k to take effect
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.eos_token_id,
                    **kwargs,
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            prompt_length = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
            logger.info("Response generated successfully.")
            return response
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise

    async def async_generate(
        self,
        input_text: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
        repetition_penalty: float = 1.0,
        **kwargs,
    ) -> str:
        """
        Asynchronously generate a response by running the blocking
        generate_response() in a worker thread, keeping the event loop free.

        Accepts the same arguments as generate_response().

        Returns:
            str: The generated response.
        """
        return await asyncio.to_thread(
            self.generate_response,
            input_text,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            **kwargs,
        )
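

# Minimal usage sketch exercising both the sync and async paths. The model
# name "gpt2" below is a hypothetical stand-in; substitute your own model
# path. This demo block is illustrative only, not part of the core class.
if __name__ == "__main__":
    bot = Charm15Chatbot("gpt2")

    # Synchronous generation
    print(bot.generate_response("Hello, how are you?", max_new_tokens=64))

    # Asynchronous generation: generate_response runs in a worker thread,
    # so the event loop stays responsive while the model computes.
    async def main() -> None:
        reply = await bot.async_generate("Tell me a joke.", max_new_tokens=64)
        print(reply)

    asyncio.run(main())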