Advertisement
klassekatze

Untitled

Mar 25th, 2025
483
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.27 KB | None | 0 0
  1. import re
  2. from flask import Flask, request, Response
  3. import requests
  4. import json
  5. from flask_cors import CORS
  6. from decoder import turn_token_into_id, convert_to_audio, tokens_decoder_sync
  7. import lameenc  # Import LAME MP3 encoder
  8.  
  9. app = Flask(__name__)
  10. CORS(app)
  11.  
  12. DEFAULT_VOICE = "tara"
  13. COMPLETION_API_URL = "http://127.0.0.1:5111/v1/completions"
  14.  
  15. def format_prompt(prompt: str, voice=DEFAULT_VOICE) -> str:
  16.     return f"<|audio|>{voice}: {prompt}<|eot_id|>"
  17.  
  18. @app.route("/v1/audio/speech", methods=["POST"])
  19. def generate_audio_stream():
  20.     req_data = request.get_json()
  21.    
  22.     text_input = req_data.get("input") or req_data.get("text")
  23.     if not text_input:
  24.         return {"error": "Missing 'text' or 'input' parameter"}, 400
  25.    
  26.     text_input = text_input.strip()
  27.     print(f"Received input: {text_input}")
  28.  
  29.     formatted_prompt = format_prompt(text_input)
  30.    
  31.     payload = {
  32.         "prompt": formatted_prompt,
  33.         "max_tokens": 2000,
  34.         "temperature": 0.4,
  35.         "top_p": 0.9,
  36.         "repetition_penalty": 1.1,
  37.         "stream": True,
  38.     }
  39.  
  40.     def sse_to_token_generator():  # Modified to remove WAV header
  41.         with requests.post(
  42.             COMPLETION_API_URL,
  43.             json=payload,
  44.             headers={"Accept": "text/event-stream"},
  45.             stream=True
  46.         ) as response:
  47.            
  48.             if response.status_code != 200:
  49.                 raise ValueError(f"API Error ({response.status_code})")
  50.  
  51.             buffer = []
  52.             token_count = 0
  53.  
  54.             for line in response.iter_lines(decode_unicode=True):
  55.                
  56.                 if not line.strip():  
  57.                     continue
  58.                
  59.                 if line.startswith("event: "):
  60.                     continue
  61.                    
  62.                 if line.startswith("data:"):
  63.                     data_part = line[5:].strip()  
  64.  
  65.                     if data_part == "[DONE]":
  66.                         print("Received [DONE], stopping generation")
  67.                         break
  68.  
  69.                     try:
  70.                         event_data = json.loads(data_part)
  71.                        
  72.                         text = event_data["choices"][0]["text"]
  73.                         tokens = re.findall(r"<custom_token_\d+>", text)
  74.                        
  75.                         for tok in tokens:
  76.                             token_id = turn_token_into_id(tok, token_count)
  77.                             if token_id is not None and token_id > 0:
  78.                                 buffer.append(token_id)
  79.                                 token_count += 1
  80.  
  81.                                 if token_count % 7 == 0 and len(buffer) >= 7:
  82.                                     multiframe = buffer[-28:]  # Keep last 28 tokens (original window size)
  83.                                     audio_pcm = convert_to_audio(multiframe, token_count)
  84.                                    
  85.                                     if audio_pcm:
  86.                                         yield audio_pcm  # Only PCM bytes now
  87.  
  88.                     except Exception as e:
  89.                         print(f"Parsing error on '{line}': {str(e)}")
  90.                         continue
  91.    
  92.     # New MP3 wrapper generator
  93.     def token_to_mp3_generator():
  94.         encoder = lameenc.Encoder()
  95.         encoder.set_bit_rate(64)          # Adjust bitrate as needed; lower is faster
  96.         encoder.set_in_sample_rate(24000) # Matches model's sample rate
  97.         encoder.set_channels(1)           # Mono audio from SNAC model
  98.         encoder.silence()                 # Disable LAME debug output
  99.         encoder.set_quality(7)            # Fastest encoding for low latency
  100.        
  101.         for pcm_chunk in sse_to_token_generator():  # Process each PCM chunk
  102.             assert isinstance(pcm_chunk, bytes)
  103.             mp3_data = encoder.encode(pcm_chunk)
  104.             if mp3_data:
  105.                 yield bytes(mp3_data)
  106.            
  107.         final_mp3 = encoder.flush()       # Flush any remaining data
  108.         if final_mp3:
  109.             yield bytes(final_mp3)
  110.  
  111.     return Response(
  112.         token_to_mp3_generator(),
  113.         mimetype="audio/mpeg",
  114.         headers={"Transfer-Encoding": "chunked"}
  115.     )
  116.  
  117. if __name__ == "__main__":
  118.     app.run(host="0.0.0.0", port=5000, debug=True)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement