Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import re
- from flask import Flask, request, Response
- import requests
- import json
- from flask_cors import CORS
- from decoder import turn_token_into_id, convert_to_audio, tokens_decoder_sync
- import lameenc # Import LAME MP3 encoder
# Flask application object; the speech route below is registered on it.
app = Flask(import_name=__name__)
# Register Flask-CORS on the app with no explicit options.
# NOTE(review): flask_cors's defaults allow cross-origin requests from any
# origin on every route (per its docs) — confirm that is intended here.
CORS(app=app)
# Speaker name used when format_prompt() is called without an explicit voice.
DEFAULT_VOICE = "tara"
# Completions endpoint that the speech route streams audio tokens from.
COMPLETION_API_URL = "http://127.0.0.1:5111/v1/completions"


def format_prompt(prompt: str, voice: str = DEFAULT_VOICE) -> str:
    """Wrap *prompt* in the model's speech-prompt template.

    The returned string has the shape ``<|audio|>{voice}: {prompt}<|eot_id|>``;
    *prompt* is inserted verbatim (no escaping of special tokens).
    """
    template = "<|audio|>{}: {}<|eot_id|>"
    return template.format(voice, prompt)
# Matches the model's audio-code tokens, e.g. "<custom_token_4242>".
_CUSTOM_TOKEN_RE = re.compile(r"<custom_token_\d+>")

# (connect, read) timeouts in seconds for the upstream completion request.
# With stream=True the read timeout bounds each wait for the response headers
# and for the next streamed chunk, not the total generation time.
_UPSTREAM_TIMEOUT = (10, 300)


def _iter_sse_texts(response):
    """Yield ``choices[0]["text"]`` from each ``data:`` event of an SSE stream.

    Lines are decoded explicitly as UTF-8 (the SSE encoding); letting
    ``requests`` decode them may use ISO-8859-1 or yield raw bytes when the
    upstream sends no charset. Lines not starting with ``data:`` (blank lines,
    ``event:`` lines, comments) are ignored. Iteration stops at the ``[DONE]``
    sentinel. Payloads that are not JSON, lack the expected structure, or whose
    ``text`` is not a string are logged and skipped.
    """
    for raw_line in response.iter_lines():
        line = raw_line.decode("utf-8", errors="replace")
        if not line.startswith("data:"):
            continue
        data_part = line[5:].strip()
        if data_part == "[DONE]":
            print("Received [DONE], stopping generation")
            return
        try:
            text = json.loads(data_part)["choices"][0]["text"]
            if not isinstance(text, str):
                raise TypeError(f"'text' is {type(text).__name__}, expected str")
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # ValueError covers json.JSONDecodeError.
            print(f"Parsing error on '{line}': {e}")
            continue
        yield text


def _iter_pcm_chunks(texts):
    """Turn streamed completion text into PCM audio chunks.

    Each ``<custom_token_N>`` found in the text is mapped with
    ``turn_token_into_id``; results that are ``None`` or not positive are
    dropped and do not advance the token counter. After every 7th accepted
    token, the most recent (up to) 28 ids are passed to ``convert_to_audio``
    and any non-empty result is yielded.

    Decoding is best-effort: if a helper raises, the error is logged, the rest
    of that text chunk's tokens are skipped, and streaming continues.
    """
    buffer = []
    token_count = 0
    for text in texts:
        try:
            for tok in _CUSTOM_TOKEN_RE.findall(text):
                token_id = turn_token_into_id(tok, token_count)
                if token_id is None or token_id <= 0:
                    continue
                buffer.append(token_id)
                token_count += 1
                if token_count % 7 == 0 and len(buffer) >= 7:
                    # Keep last 28 tokens (original window size).
                    audio_pcm = convert_to_audio(buffer[-28:], token_count)
                    if audio_pcm:
                        yield audio_pcm
        except Exception as e:  # best-effort: one bad frame must not end the stream
            print(f"Decoding error on {text!r}: {e}")


def _encode_mp3(pcm_chunks):
    """Re-encode an iterable of PCM ``bytes`` chunks into MP3 byte chunks.

    NOTE(review): lameenc expects 16-bit signed PCM; presumably that is what
    ``convert_to_audio`` returns — confirm against the decoder module.

    Raises:
        TypeError: if a chunk is not ``bytes``.
    """
    encoder = lameenc.Encoder()
    encoder.set_bit_rate(64)  # Adjust bitrate as needed; lower is faster
    encoder.set_in_sample_rate(24000)  # Matches model's sample rate
    encoder.set_channels(1)  # Mono audio from SNAC model
    encoder.silence()  # Disable LAME debug output
    encoder.set_quality(7)  # Fastest encoding for low latency
    for pcm_chunk in pcm_chunks:
        if not isinstance(pcm_chunk, bytes):
            raise TypeError(f"expected PCM bytes, got {type(pcm_chunk).__name__}")
        mp3_data = encoder.encode(pcm_chunk)
        if mp3_data:
            yield bytes(mp3_data)
    final_mp3 = encoder.flush()  # Flush any remaining buffered frames
    if final_mp3:
        yield bytes(final_mp3)


@app.route("/v1/audio/speech", methods=["POST"])
def generate_audio_stream():
    """Stream MP3 speech synthesized from the request's ``input``/``text`` field.

    The JSON body must be an object with a non-empty string ``input`` (``text``
    is used as a fallback). The text is wrapped with ``format_prompt`` using the
    default voice and sent to ``COMPLETION_API_URL`` as a streaming completion.
    The returned audio tokens are decoded and re-encoded to MP3 on the fly.

    Returns:
        400 with a JSON error if the body is not a JSON object or the text is
        missing, blank or not a string; 502 with a JSON error if the completion
        API cannot be reached or answers with a non-200 status; otherwise a
        streamed ``audio/mpeg`` response.
    """
    req_data = request.get_json(silent=True)
    if not isinstance(req_data, dict):
        return {"error": "Request body must be a JSON object"}, 400
    text_input = req_data.get("input") or req_data.get("text")
    if not isinstance(text_input, str) or not text_input.strip():
        return {"error": "Missing 'text' or 'input' parameter"}, 400
    text_input = text_input.strip()
    print(f"Received input: {text_input}")

    payload = {
        "prompt": format_prompt(text_input),
        "max_tokens": 2000,
        "temperature": 0.4,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "stream": True,
    }

    # Open the upstream stream before returning, so connection failures and
    # non-200 answers become a real error status instead of being raised
    # while the streamed body is being produced.
    try:
        upstream = requests.post(
            COMPLETION_API_URL,
            json=payload,
            headers={"Accept": "text/event-stream"},
            stream=True,
            timeout=_UPSTREAM_TIMEOUT,
        )
    except requests.RequestException as e:
        print(f"Completion API request failed: {e}")
        return {"error": "Completion API request failed"}, 502
    if upstream.status_code != 200:
        upstream.close()
        return {"error": f"API Error ({upstream.status_code})"}, 502

    # No manual "Transfer-Encoding: chunked" header: it is hop-by-hop, which
    # WSGI applications must not set (PEP 3333). The server applies chunked
    # framing itself to bodies without a Content-Length.
    response = Response(
        _encode_mp3(_iter_pcm_chunks(_iter_sse_texts(upstream))),
        mimetype="audio/mpeg",
    )
    # Release the upstream connection whenever the response is closed. This
    # also covers a client disconnect and a body that is never iterated.
    response.call_on_close(upstream.close)
    return response
if __name__ == "__main__":
    # Listen on all interfaces so other machines can reach the API.
    # The debugger is deliberately not forced on. Werkzeug's interactive
    # debugger allows arbitrary code execution and must never be reachable
    # over the network.
    # For local development, set FLASK_DEBUG=1; Flask's app.run() honours it
    # when no explicit debug= argument is passed.
    app.run(host="0.0.0.0", port=5000)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement