72 lines
2.5 KiB
Python
72 lines
2.5 KiB
Python
"""Скрипт предзагрузки моделей для Docker образа.
|
|
|
|
Запускается во время сборки Docker образа.
|
|
Загружает все необходимые модели, чтобы при старте контейнера
|
|
не нужно было ждать скачивания.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
|
|
# Устанавливаем токен HuggingFace из env или placeholder
|
|
os.environ.setdefault("HF_TOKEN", os.environ.get("HF_TOKEN", ""))
|
|
|
|
import whisperx
|
|
|
|
|
|
def download_whisper_model(model_name="large-v3"):
|
|
"""Загружает модель Whisper."""
|
|
print(f"[Download] Whisper модель: {model_name}")
|
|
device = "cpu"
|
|
compute_type = "int8"
|
|
model = whisperx.load_model(model_name, device, compute_type=compute_type)
|
|
print(f"[Download] Whisper {model_name} готова")
|
|
del model
|
|
|
|
|
|
def download_alignment_model(language="ru"):
|
|
"""Загружает alignment модель для языка."""
|
|
print(f"[Download] Alignment модель для {language}")
|
|
device = "cpu"
|
|
model_a, metadata = whisperx.load_align_model(language_code=language, device=device)
|
|
print(f"[Download] Alignment {language} готова")
|
|
del model_a
|
|
|
|
|
|
def download_diarization_model():
|
|
"""Загружает модель диаризации."""
|
|
print("[Download] Модель диаризации")
|
|
token = os.environ.get("HF_TOKEN", "")
|
|
if not token:
|
|
print("[Warning] HF_TOKEN не установлен — модель диаризации не загружена!")
|
|
print("[Warning] При старте контейнера установите HF_TOKEN через env.")
|
|
return
|
|
from whisperx.diarize import DiarizationPipeline
|
|
device = "cpu"
|
|
diarize_model = DiarizationPipeline(token=token, device=device)
|
|
print("[Download] Диаризация готова")
|
|
del diarize_model
|
|
|
|
|
|
def main():
|
|
print("=" * 60)
|
|
print("Предзагрузка моделей для Docker образа")
|
|
print("=" * 60)
|
|
|
|
download_whisper_model("large-v3")
|
|
download_alignment_model("ru")
|
|
|
|
# Диаризация требует токен, поэтому опционально
|
|
try:
|
|
download_diarization_model()
|
|
except Exception as e:
|
|
print(f"[Warning] Диаризация не загружена: {e}")
|
|
|
|
print("=" * 60)
|
|
print("Предзагрузка завершена")
|
|
print("=" * 60)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|