OpenVoice Apple M3 Max硬件优化

```bash
# 创建虚拟环境
python3 -m venv openvoice_env
source openvoice_env/bin/activate

# 安装依赖
pip install -e .
pip install git+https://github.com/myshell-ai/MeloTTS.git
pip install psutil unidecode
# MeloTTS 安装说明要求下载 unidic 词典（日语文本处理需要）
python -m unidic download
```

下载 OpenVoice V2 检查点：

```bash
curl -L -o checkpoints_v2_0417.zip "https://myshell-public-repo-host.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip"
unzip checkpoints_v2_0417.zip
```

文件结构验证（解压后应得到如下目录结构）：

```text
checkpoints_v2/
├── converter/
│   ├── config.json
│   └── checkpoint.pth (131MB)
└── base_speakers/ses/
    ├── jp.pth (日语说话者嵌入)
    ├── en-default.pth
    └── 其他语言…
```

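最佳选择：M3 Max 快速模式

将下方 Python 代码保存为 `openvoice/openvoice_app_m3_fast.py`，然后运行（界面监听 7866 端口；`--share` 会额外生成公网链接）：

```bash
cd /Users/Documents/Openvoice/OpenVoice  # 替换为你本地 OpenVoice 仓库的实际路径
source openvoice_env/bin/activate
export TOKENIZERS_PARALLELISM=false
python -m openvoice.openvoice_app_m3_fast --share
```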

import os
import torch
import argparse
import gradio as gr
import langid
import soundfile as sf
import numpy as np
from openvoice import se_extractor
from openvoice.api import ToneColorConverter
import time
import multiprocessing

# Disable HuggingFace tokenizers parallelism to avoid its warning; set at
# import time, before any model is loaded further down.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

parser = argparse.ArgumentParser()
parser.add_argument("--share", action='store_true', default=False, help="make link public")
# Parsing happens at import time, so use parse_known_args: arguments that belong
# to whatever launched this process (e.g. a notebook kernel's "-f <file>") are
# ignored with a warning instead of aborting the import with SystemExit.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    print(f"⚠️ Ignoring unrecognized arguments: {_unknown_args}")

# Device selection tuned for Apple Silicon (M3 Max).
def get_optimized_device():
    """Return the torch device string to run models on.

    Returns ``"mps"`` when PyTorch's Metal (MPS) backend is available.
    Otherwise returns ``"cpu"`` after capping PyTorch's thread count at
    ``min(16, cpu_count)``.
    """
    mps_ready = torch.backends.mps.is_available()
    if not mps_ready:
        print("⚠️ MPS not available, using CPU with optimizations")
        thread_cap = min(16, multiprocessing.cpu_count())
        torch.set_num_threads(thread_cap)
        return "cpu"

    print("🚀 Using Apple M3 Max GPU (MPS) for acceleration!")
    return "mps"

# Resolve the compute device once at import time (shared by the model-loading
# code) and report the runtime configuration.
device = get_optimized_device()
print(
    f"Device: {device}\n"
    f"CPU cores: {multiprocessing.cpu_count()}\n"
    f"PyTorch threads: {torch.get_num_threads()}"
)

# OpenVoice V2 converter checkpoint location and the output directory for
# generated audio (created if missing).
ckpt_converter_v2 = 'checkpoints_v2/converter'
output_dir = 'outputs_m3_fast'
os.makedirs(output_dir, exist_ok=True)

# Load the tone-color converter once at startup and time the load.
print("Loading OpenVoice V2 with M3 Max optimizations...")
start_time = time.time()

tone_color_converter_v2 = ToneColorConverter(f'{ckpt_converter_v2}/config.json', device=device)
tone_color_converter_v2.load_ckpt(f'{ckpt_converter_v2}/checkpoint.pth')

if device == "mps":
    # Release allocator memory cached during checkpoint loading.
    # The public API is torch.mps.empty_cache() (PyTorch >= 2.0);
    # torch.backends.mps has no empty_cache, so calling that always raised
    # AttributeError and silently did nothing.
    try:
        torch.mps.empty_cache()
    except AttributeError:  # older PyTorch without the torch.mps module
        pass
    print("✅ MPS optimizations enabled")

load_time = time.time() - start_time
print(f"✅ Model loaded in {load_time:.2f} seconds")

# Preload the base-speaker embeddings onto the compute device once, so
# requests can look them up in memory instead of reading from disk.
speaker_embeddings_v2 = {}
ses_dir = 'checkpoints_v2/base_speakers/ses'
print(f"Preloading speaker embeddings to {device}...")

# App language code -> embedding file name under ses_dir.
language_file_mapping = {
    'en': 'en-default.pth',
    'jp': 'jp.pth',
    'zh': 'zh.pth',
    'es': 'es.pth',
    'fr': 'fr.pth',
    'kr': 'kr.pth'
}

# Extra English accent embeddings, stored under their file stem (e.g. 'en-us').
en_variants = ['en-us.pth', 'en-br.pth', 'en-au.pth', 'en-india.pth', 'en-newest.pth']


def _preload_speaker_embedding(key, filename):
    """Load ``ses_dir/filename`` onto ``device`` and store it under *key*.

    Missing files are skipped so a partial checkpoint download still works.

    Returns:
        True if the embedding was loaded, False if the file does not exist.
    """
    file_path = os.path.join(ses_dir, filename)
    if not os.path.exists(file_path):
        return False
    # map_location already places the tensor on `device`; no extra .to() is needed.
    speaker_embeddings_v2[key] = torch.load(file_path, map_location=device)
    return True


for lang_code, filename in language_file_mapping.items():
    if _preload_speaker_embedding(lang_code, filename):
        print(f"  ✅ Loaded {lang_code}: {filename}")

for variant_file in en_variants:
    _preload_speaker_embedding(os.path.splitext(variant_file)[0], variant_file)

print(f"✅ Preloaded {len(speaker_embeddings_v2)} speaker embeddings to {device}")

# Display names for the supported app language codes. Dict order defines
# supported_languages_v2, which is the order offered in the language dropdown.
language_names = {
    'zh': 'Chinese',
    'en': 'English',
    'jp': 'Japanese',
    'es': 'Spanish',
    'fr': 'French',
    'kr': 'Korean',
}
supported_languages_v2 = list(language_names)

# MeloTTS language codes are the upper-cased app codes (e.g. 'jp' -> 'JP').
melotts_language_mapping = {code: code.upper() for code in supported_languages_v2}

# Lazily filled cache of MeloTTS models, keyed by MeloTTS language code.
melotts_models = {}

def get_melotts_model(language):
    """Return the MeloTTS model for *language*, loading it on first use.

    Args:
        language: MeloTTS language code, e.g. ``'EN'`` or ``'JP'``.

    Returns:
        The cached ``melo.api.TTS`` instance, or ``None`` if loading failed.
        Failures are not cached, so a later call retries the load.
    """
    if language in melotts_models:
        print(f"♻️ Using cached MeloTTS {language} model")
        return melotts_models[language]

    try:
        # Imported here so MeloTTS is only imported when a model is first needed.
        from melo.api import TTS
        print(f"🔄 Loading MeloTTS model for {language}...")
        load_start = time.time()
        model = TTS(language=language, device=device)
    except Exception as e:  # boundary: caller treats None as "model unavailable"
        print(f"❌ Failed to load MeloTTS {language} model: {e}")
        return None

    melotts_models[language] = model
    print(f"✅ MeloTTS {language} model loaded in {time.time() - load_start:.2f} seconds")
    return model

# langid ISO 639-1 result -> this app's language code.
_LANGID_TO_APP_LANGUAGE = {
    'ja': 'jp', 'en': 'en', 'zh': 'zh',
    'es': 'es', 'fr': 'fr', 'ko': 'kr',
}


def _release_mps_cache():
    """Free cached MPS allocator memory when running on the Apple GPU.

    ``torch.mps.empty_cache()`` is the public API (PyTorch >= 2.0);
    ``torch.backends.mps`` has no ``empty_cache``, so calling that one never
    released anything.
    """
    if device != "mps":
        return
    try:
        torch.mps.empty_cache()
    except AttributeError:  # older PyTorch without the torch.mps module
        pass


def _resolve_audio_path(audio_file):
    """Return ``(file_stem, path)`` for an upload given as a path or an object with ``.name``."""
    path = audio_file.name if hasattr(audio_file, 'name') else audio_file
    return os.path.splitext(os.path.basename(path))[0], path


def process_audio_m3_fast(prompt, language, audio_file, agree):
    """Speak *prompt* in *language* with the voice of the reference audio.

    Steps:
    1. Extract the target speaker embedding from *audio_file*.
    2. Synthesize the text with a MeloTTS base speaker.
    3. Run OpenVoice tone-color conversion from that base speaker to the target.

    Args:
        prompt: Text to synthesize.
        language: App language code (one of ``supported_languages_v2``).
        audio_file: Reference audio, as a path string or an object with ``.name``.
        agree: Whether the user accepted the terms of use.

    Returns:
        ``(log_text, output_wav_path_or_None, reference_audio_path_or_None)``,
        matching the three Gradio output components.
    """
    if not agree:
        return ("Please accept the terms of use first.", None, None)

    if not audio_file:
        return ("Please upload a reference audio file.", None, None)

    # Empty text would otherwise only fail inside MeloTTS, after the slow
    # speaker-embedding extraction has already run.
    if not prompt or not prompt.strip():
        return ("Please enter text to synthesize.", None, None)

    total_start_time = time.time()
    text_hint = "🚀 M3 Max Fast Processing Started\n"
    text_hint += f"Device: {device.upper()}\n"
    text_hint += f"CPU Cores: {multiprocessing.cpu_count()}\n\n"

    try:
        audio_name, audio_file_pth = _resolve_audio_path(audio_file)

        text_hint += f"Processing: {audio_name}\n"
        text_hint += f"Text: {prompt}\n"
        text_hint += f"Language: {language_names.get(language, language)}\n"

        # Warn (but continue) when the text does not look like the selected language.
        language_predicted = langid.classify(prompt)[0].strip()
        mapped_language = _LANGID_TO_APP_LANGUAGE.get(language_predicted, language_predicted)
        if mapped_language != language:
            text_hint += f"⚠️ Detected: {language_predicted}, Selected: {language}\n"

        if language not in speaker_embeddings_v2:
            return (f"❌ Speaker embedding for {language} not found", None, audio_file_pth)

        source_se = speaker_embeddings_v2[language]
        text_hint += f"✅ Speaker embedding loaded for {language_names.get(language, language)}\n"

        text_hint += "🔍 Extracting target speaker embedding...\n"
        se_start_time = time.time()
        # get_se also returns a name that replaces audio_name in the output file names.
        target_se, audio_name = se_extractor.get_se(
            audio_file_pth,
            tone_color_converter_v2,
            target_dir=output_dir,
            vad=True,
        )
        text_hint += f"✅ Target embedding extracted in {time.time() - se_start_time:.2f}s\n"

        try:
            melotts_lang = melotts_language_mapping.get(language, language.upper())
            text_hint += f"🎵 Generating speech with MeloTTS ({melotts_lang})...\n"

            tts_start_time = time.time()
            model = get_melotts_model(melotts_lang)
            if model is None:
                raise RuntimeError(f"Failed to load MeloTTS model for {melotts_lang}")

            # Use the first speaker MeloTTS defines for this language.
            # spk2id is kept as .keys() (not iter()), since it may be a non-dict mapping.
            speaker_ids = model.hps.data.spk2id
            speaker_key = list(speaker_ids.keys())[0]
            speaker_id = speaker_ids[speaker_key]

            # The converter's source embedding should belong to the base speaker
            # actually used. Prefer the embedding named after that speaker.
            # Assumption (the OpenVoice V2 demo convention): lowercase the key and
            # map '_' -> '-', e.g. 'EN_INDIA' -> 'en-india'. Otherwise fall back
            # to the per-language default.
            source_se = speaker_embeddings_v2.get(
                speaker_key.lower().replace('_', '-'), source_se
            )
            text_hint += f"🗣️ Base speaker: {speaker_key}\n"

            src_path = f'{output_dir}/base_{language}_{audio_name}.wav'
            speed = 1.0

            _release_mps_cache()
            model.tts_to_file(prompt, speaker_id, src_path, speed=speed)
            text_hint += f"✅ Base speech generated in {time.time() - tts_start_time:.2f}s\n"

            text_hint += "🎨 Applying tone color conversion...\n"
            convert_start_time = time.time()
            final_path = f'{output_dir}/cloned_{language}_{audio_name}.wav'

            _release_mps_cache()
            # `message` is handed to the converter; presumably embedded as a
            # watermark in the output (OpenVoice convention) — confirm.
            tone_color_converter_v2.convert(
                audio_src_path=src_path,
                src_se=source_se,
                tgt_se=target_se,
                output_path=final_path,
                message="@MyShell",
            )

            convert_time = time.time() - convert_start_time
            total_time = time.time() - total_start_time

            text_hint += f"✅ Voice cloning completed in {convert_time:.2f}s\n"
            text_hint += f"🎉 Total processing time: {total_time:.2f}s\n"
            text_hint += f"💾 Output: {final_path}\n"
            text_hint += f"🚀 M3 Max Fast Mode: {'GPU (MPS)' if device == 'mps' else 'CPU Optimized'}\n"

            return (text_hint, final_path, audio_file_pth)

        except Exception as e:  # synthesis/conversion failure: keep the reference audio visible
            text_hint += f"❌ Error: {e}\n"
            return (text_hint, None, audio_file_pth)

    except Exception as e:  # UI boundary: report instead of crashing the handler
        # Keep the partial log so the user can see which stage failed.
        return (text_hint + f"❌ Processing error: {e}\n", None, None)

# Gradio UI: the left column collects inputs; the right column shows the
# processing log, the generated audio, and the reference audio echoed back.
# Component creation order inside each `with` context determines the layout.
with gr.Blocks(title="OpenVoice V2 - M3 Max Fast") as demo:
    gr.Markdown("# ⚡ OpenVoice V2 - Apple M3 Max Fast Mode")
    gr.Markdown(f"**Device**: {device.upper()} | **CPU Cores**: {multiprocessing.cpu_count()} | **Mode**: Fast Loading")
    
    with gr.Row():
        # Input column.
        with gr.Column():
            prompt = gr.Textbox(
                label="Text to synthesize",
                placeholder="Enter text in any supported language...",
                lines=3,
                max_lines=5
            )
            
            # Choices are (display label, value) pairs; the value passed to the
            # handler is the short app language code. Japanese is preselected.
            language = gr.Dropdown(
                choices=[(f"{language_names[lang]} ({lang})", lang) for lang in supported_languages_v2],
                value="jp",
                label="Language"
            )
            
            # type="filepath": presumably the handler receives a path string
            # (process_audio_m3_fast also accepts objects exposing .name).
            audio_file = gr.Audio(
                label="Reference Audio",
                type="filepath"
            )
            
            # Must be ticked; the handler refuses to run otherwise.
            agree = gr.Checkbox(
                label="I agree to the terms of use",
                value=False
            )
            
            submit_btn = gr.Button("⚡ Fast Generate with M3 Max", variant="primary")
        
        # Output column.
        with gr.Column():
            output_text = gr.Textbox(
                label="M3 Max Fast Processing Log",
                lines=12,
                max_lines=20
            )
            
            output_audio = gr.Audio(
                label="Generated Audio"
            )
            
            reference_audio = gr.Audio(
                label="Reference Audio"
            )
    
    # Feature summary shown below the form.
    gr.Markdown("""
    ### ⚡ M3 Max Fast Mode Features
    - **Lazy Loading**: MeloTTS models load only when needed
    - **Model Caching**: Avoid repeated downloads
    - **GPU Acceleration**: Apple MPS for neural network inference
    - **Memory Optimization**: Smart GPU memory management
    - **Fast Processing**: Optimized for repeated usage
    """)
    
    # Clickable examples that prefill the four inputs. The reference audio
    # path is relative, so it resolves against the working directory.
    gr.Examples(
        examples=[
            ["こんにちは、Apple M3 Max高速モードです。", "jp", "resources/example_reference.mp3", True],
            ["Hello, this is M3 Max fast mode!", "en", "resources/example_reference.mp3", True],
            ["你好,这是M3 Max快速模式。", "zh", "resources/example_reference.mp3", True],
        ],
        inputs=[prompt, language, audio_file, agree],
        label="Fast Mode Examples"
    )
    
    # The handler returns a 3-tuple that fills the three outputs in order:
    # (log text, generated audio path, reference audio path).
    submit_btn.click(
        fn=process_audio_m3_fast,
        inputs=[prompt, language, audio_file, agree],
        outputs=[output_text, output_audio, reference_audio]
    )

if __name__ == "__main__":
    # Serve on fixed port 7866; --share additionally requests a public Gradio link.
    demo.launch(server_port=7866, share=args.share)