import logging
from typing import Any, Dict, List, Optional

import numpy as np
import transformers

from .ultravox_model import UltravoxModel
from .ultravox_processing import UltravoxProcessor


class UltravoxPipeline(transformers.Pipeline):
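    """Speech-and-text chat pipeline built on UltravoxModel.

    Wraps a text tokenizer and an audio processor in an UltravoxProcessor so
    that callers can pass raw audio plus chat turns and receive generated
    text back.
    """
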
    def __init__(
        self,
        model: UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs,
    ):
        if tokenizer is None:
            try:
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config._name_or_path
                )
            except Exception:
                # Fall back to the text backbone's tokenizer when the combined
                # checkpoint does not ship one of its own.
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config.text_model_id or model.config.text_config._name_or_path
                )

        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_model_id or model.config.audio_config._name_or_path
            )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

        # stack_factor controls how many consecutive audio frames the projector
        # stacks into a single embedding; audio_context_size caps the audio
        # tower's input length.
        self.processor = UltravoxProcessor(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.stack_factor,
            audio_context_size=model.audio_tower_context_length,
        )
    def _sanitize_parameters(self, **kwargs):
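        """Split call kwargs into the (preprocess, forward, postprocess)
        triple that transformers.Pipeline expects; only generation-related
        keys are routed, all of them to `_forward`."""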
        generation_keys = ["temperature", "max_new_tokens", "repetition_penalty"]
        generation_kwargs = {k: v for k, v in kwargs.items() if k in generation_keys}
        return {}, generation_kwargs, {}
    def preprocess(self, inputs: Dict[str, Any]):
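        """Build model inputs from a dict with optional "audio",
        "sampling_rate", "turns", and "prompt" keys.

        Audio may be a numpy array (float64/int16/int32 are converted to
        float32); turns are rendered with the tokenizer's chat template.
        """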
        turns: list = inputs.get("turns", [])

        audio = inputs.get("audio", None)
        # Normalize audio to float32: downcast float64, and scale integer PCM
        # (int16/int32) into the [-1.0, 1.0) range the audio processor expects.
        if isinstance(audio, np.ndarray):
            if audio.dtype == np.float64:
                audio = audio.astype(np.float32)
            elif audio.dtype == np.int16:
                audio = audio.astype(np.float32) / np.float32(32768.0)
            elif audio.dtype == np.int32:
                audio = audio.astype(np.float32) / np.float32(2147483648.0)

        # If audio is provided but the conversation does not end with a user
        # turn, wrap the prompt (default "<|audio|>") into a new user turn.
        if audio is not None and (len(turns) == 0 or turns[-1]["role"] != "user"):
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>'; appending '<|audio|>' to the end of the prompt."
                )
                prompt += " <|audio|>"
            turns.append({"role": "user", "content": prompt})

        text = self.processor.tokenizer.apply_chat_template(
            turns, add_generation_prompt=True, tokenize=False
        )

        if "sampling_rate" not in inputs and audio is not None:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        output = self.processor(
            text=text,
            audio=audio,
            sampling_rate=inputs.get("sampling_rate", 16000),
        )
        if "audio_values" in output:
            # Cast the audio features to the model's dtype (e.g. bfloat16).
            output["audio_values"] = output["audio_values"].to(self.model.dtype)

        return output
    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
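        """Run generation and return only the newly generated token ids;
        the prompt tokens are stripped from the output."""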
        # A temperature of 0 (or None) means greedy decoding; any positive
        # value enables sampling.
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on the model's EOS token, and also on the "<|eot_id|>"
        # end-of-turn token (Llama 3) when the tokenizer defines it.
        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators,
        )
        return outputs[0][input_len:]
    def postprocess(self, model_outputs) -> str:
        output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
        return output_text


transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "ultravox-pipeline",
    pipeline_class=UltravoxPipeline,
    pt_model=transformers.AutoModel,
    type="multimodal",
)
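# A minimal usage sketch (illustrative, not part of this module): it assumes a
# checkpoint that registers this pipeline via trust_remote_code and a 16kHz
# mono float32 numpy array `audio`; the model id below is only an example.
#
#   pipe = transformers.pipeline(
#       model="fixie-ai/ultravox-v0_4", trust_remote_code=True
#   )
#   result = pipe(
#       {
#           "audio": audio,
#           "sampling_rate": 16000,
#           "turns": [{"role": "user", "content": "What is said here? <|audio|>"}],
#       },
#       max_new_tokens=64,
#   )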