| import os |
| import io |
| import base64 |
| import torch |
| import numpy as np |
| from transformers import BarkModel, BarkProcessor |
| from typing import Dict, List, Any |
|
|
class EndpointHandler:
    """Inference-endpoint handler for the Bark text-to-speech model.

    Construct with a model path, then call the instance with a request
    dict such as ``{"inputs": "Hello", "parameters": {...}}``.  The model
    and processor are loaded lazily on first use.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler for the Bark text-to-speech model.

        Args:
            path (str, optional): Path to the model directory (or hub id).
                Defaults to "".
        """
        self.path = path
        self.model = None
        self.processor = None
        # Prefer GPU when available; Bark generation on CPU is very slow.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.initialized = False

    def setup(self, **kwargs):
        """
        Load the model and processor onto the target device.

        Args:
            **kwargs: Additional arguments (unused; kept for interface
                compatibility with serving frameworks).
        """
        self.model = BarkModel.from_pretrained(self.path)
        self.model.to(self.device)
        self.processor = BarkProcessor.from_pretrained(self.path)
        self.initialized = True
        print(f"Bark model loaded on {self.device}")

    def preprocess(self, request: Dict) -> Dict:
        """
        Extract text and generation parameters from the raw request.

        Args:
            request (Dict): Request data; expects "inputs" (str or list of
                str) and an optional "parameters" dict.

        Returns:
            Dict: Normalized inputs: "text" plus any of "speaker_id",
            "voice_preset", "temperature".
        """
        if not self.initialized:
            self.setup()

        inputs: Dict[str, Any] = {}

        payload = request.get("inputs")
        if isinstance(payload, str):
            inputs["text"] = payload
        elif isinstance(payload, list) and payload:
            # Batch requests are not supported: only the first item is used.
            # The emptiness check avoids an IndexError on "inputs": [].
            inputs["text"] = payload[0]

        params = request.get("parameters", {})

        # "speaker_id" takes precedence over "voice_preset" when both are sent.
        if "speaker_id" in params:
            inputs["speaker_id"] = params["speaker_id"]
        elif "voice_preset" in params:
            inputs["voice_preset"] = params["voice_preset"]

        if "temperature" in params:
            inputs["temperature"] = params["temperature"]

        return inputs

    def inference(self, inputs: Dict) -> Dict:
        """
        Run Bark speech generation on the processed inputs.

        Args:
            inputs (Dict): Output of preprocess(); must contain "text".

        Returns:
            Dict: {"audio_array": np.ndarray, "sample_rate": int} on
            success, or {"error": str} when no text was provided.
        """
        text = inputs.get("text", "")
        if not text:
            return {"error": "No text provided for speech generation"}

        # Bark selects voices through processor "voice_preset" prompts;
        # "speaker_id" is accepted as an alias for the same mechanism.
        voice_preset = inputs.get("speaker_id") or inputs.get("voice_preset")
        temperature = inputs.get("temperature", 0.7)

        # The processor returns a BatchEncoding (input_ids plus, when a
        # preset is given, the voice history prompt); it must be unpacked
        # into generate(), not passed as a single input_ids tensor.
        if voice_preset is not None:
            encoded = self.processor(text, voice_preset=voice_preset)
        else:
            encoded = self.processor(text)
        encoded = encoded.to(self.device)

        with torch.no_grad():
            speech_output = self.model.generate(**encoded, temperature=temperature)

        audio_array = speech_output.cpu().numpy().squeeze()

        return {
            "audio_array": audio_array,
            "sample_rate": self.model.generation_config.sample_rate,
        }

    def postprocess(self, inference_output: Dict) -> Dict:
        """
        Encode the generated waveform as a base64 WAV payload.

        Args:
            inference_output (Dict): Output of inference().

        Returns:
            Dict: {"audio": <base64 wav>, "sample_rate": int,
            "format": "wav"} on success, or {"error": str}.
        """
        if "error" in inference_output:
            return {"error": inference_output["error"]}

        audio_array = inference_output.get("audio_array")
        sample_rate = inference_output.get("sample_rate", 24000)

        if audio_array is None:
            return {"error": "Error converting audio: no audio array produced"}

        try:
            import scipy.io.wavfile as wav

            # Bark emits float32 roughly in [-1, 1]; convert to clipped
            # 16-bit PCM so the WAV plays in clients that do not support
            # IEEE-float WAV files.
            if np.issubdtype(audio_array.dtype, np.floating):
                audio_array = np.clip(audio_array, -1.0, 1.0)
                audio_array = (audio_array * 32767).astype(np.int16)

            audio_buffer = io.BytesIO()
            wav.write(audio_buffer, sample_rate, audio_array)
            audio_data = audio_buffer.getvalue()

            audio_base64 = base64.b64encode(audio_data).decode("utf-8")

            return {
                "audio": audio_base64,
                "sample_rate": sample_rate,
                "format": "wav",
            }
        except Exception as e:
            return {"error": f"Error converting audio: {str(e)}"}

    def __call__(self, data: Dict) -> Dict:
        """
        Main entry point for the handler.

        Args:
            data (Dict): Request data.

        Returns:
            Dict: Response data (base64 audio payload or an error dict).
        """
        if not self.initialized:
            self.setup()

        try:
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            response = self.postprocess(outputs)
            return response
        except Exception as e:
            # Top-level boundary: convert any failure into an error response
            # rather than crashing the serving process.
            return {"error": f"Error processing request: {str(e)}"}
|
|