transcription/tests/test_native_engine.py

178 lines
6.8 KiB
Python
Raw Normal View History

"""Real-data tests for native RAG engine (no mocks, in-process).
Использует sentence-transformers реальную модель (~50 MB скачивается при первом
запуске). Время выполнения: ~10 сек (cold) + ~1 сек на warm-кейсы.
"""
import tempfile
import unittest
from pathlib import Path
from src.rag.engine import Engine, get_or_create_engine, invalidate_engine
from src.rag.engine.chunker import chunk_text
from src.rag.engine.bm25 import bm25_search
from src.rag.engine.vector import vector_search
from src.rag.engine.hybrid import rrf_fuse
SAMPLE_DOCS = {
"plan.md": (
"# План 3-го этажа\n\n"
"План 3-го этажа жилого дома. Оси: А, Б, В, Г. Размеры между осями А и Б: 5400 мм.\n"
"Квартиры: 301, 302, 303. Кухни объединены с гостиными. Санузлы раздельные.\n\n"
"## Отделка\n"
"Стены — штукатурка, покраска. Полы — ламинат. Потолки — гипсокартон.\n"
),
"auth.md": (
"# Авторизация\n\n"
"Авторизация работает через JWT-токены с TTL 24 часа.\n"
"Refresh-токен живёт 30 дней. Логика валидации в middleware.\n"
"Сессии хранятся в Redis, ключ — sha256 от user_id + jti.\n"
),
"schedule.md": (
"# График работ\n\n"
"Строительство начинается 1 июня 2026. Окончание — 30 ноября 2027.\n"
"Этапы: фундамент → стены → кровля → MEP → отделка.\n"
),
}
class ChunkerTestCase(unittest.TestCase):
def test_chunks_short_text_returns_one(self):
chunks = chunk_text("Привет, мир.", max_chars=100)
self.assertEqual(len(chunks), 1)
self.assertIn("Привет", chunks[0].text)
def test_chunks_long_text_splits_with_overlap(self):
text = "Абзац.\n\n" * 200
chunks = chunk_text(text, max_chars=400, overlap=60)
self.assertGreater(len(chunks), 1)
for c in chunks:
self.assertLessEqual(len(c.text), 1000) # с overlap
def test_empty_text_returns_empty(self):
self.assertEqual(chunk_text(""), [])
self.assertEqual(chunk_text(" \n\n "), [])
class EngineIngestTestCase(unittest.TestCase):
def setUp(self):
self._tmp = tempfile.TemporaryDirectory()
self.tmp = Path(self._tmp.name)
self.engines = []
def _make_engine(self, name: str = "coll1") -> Engine:
eng = Engine.from_paths(self.tmp / name)
self.engines.append(eng)
return eng
def tearDown(self):
for eng in self.engines:
eng.close()
for db_path in self.tmp.rglob("index.sqlite"):
invalidate_engine(db_path.parent)
import time
time.sleep(0.05)
try:
self._tmp.cleanup()
except (PermissionError, OSError):
pass
def test_index_text_returns_chunks_and_vectors(self):
eng = self._make_engine()
result = eng.index_text(SAMPLE_DOCS["plan.md"], source_path="plan.md")
self.assertFalse(result.skipped)
self.assertGreaterEqual(result.chunks_indexed, 1)
self.assertEqual(result.vectors_indexed, result.chunks_indexed)
status = eng.status()
self.assertGreaterEqual(status["chunks"], 1)
self.assertIn(status["engine"], ("sqlite-vec", "numpy"))
def test_index_text_is_idempotent(self):
eng = self._make_engine()
r1 = eng.index_text(SAMPLE_DOCS["plan.md"], source_path="plan.md")
r2 = eng.index_text(SAMPLE_DOCS["plan.md"], source_path="plan.md")
self.assertFalse(r1.skipped)
self.assertTrue(r2.skipped)
def test_index_file_change_detected(self):
eng = self._make_engine()
f = self.tmp / "x.md"
f.write_text("first version", encoding="utf-8")
r1 = eng.index_file(f)
self.assertFalse(r1.skipped)
f.write_text("second version with new content", encoding="utf-8")
r2 = eng.index_file(f)
self.assertFalse(r2.skipped)
def test_index_many_files(self):
eng = self._make_engine()
for name, text in SAMPLE_DOCS.items():
eng.index_text(text, source_path=name)
status = eng.status()
self.assertEqual(status["chunks"], 3)
self.assertEqual(status["files"], 3)
class EngineSearchTestCase(unittest.TestCase):
"""Реальные тесты с warm-embedding."""
@classmethod
def setUpClass(cls):
cls._tmp = tempfile.TemporaryDirectory()
cls.tmp = Path(cls._tmp.name)
cls.eng = Engine.from_paths(cls.tmp / "coll_search")
for name, text in SAMPLE_DOCS.items():
cls.eng.index_text(text, source_path=name)
cls.eng.warmup()
@classmethod
def tearDownClass(cls):
cls.eng.close()
invalidate_engine(cls.tmp / "coll_search")
import time
time.sleep(0.05)
try:
cls._tmp.cleanup()
except (PermissionError, OSError):
pass
def test_bm25_search_finds_keywords(self):
hits = self.eng.search("авторизация JWT")
self.assertGreater(len(hits), 0)
self.assertEqual(hits[0].file_path, "auth.md")
def test_bm25_search_empty_query(self):
self.assertEqual(self.eng.search(""), [])
self.assertEqual(self.eng.search(" "), [])
def test_vector_search_finds_semantic(self):
# "как устроена авторизация" — семантически близко к "авторизация"
hits = self.eng.vsearch("как устроена авторизация")
self.assertGreater(len(hits), 0)
# top hit должен быть auth.md или schedule.md (есть слово "логика")
paths = [h.file_path for h in hits]
self.assertTrue(any("auth" in p for p in paths))
def test_hybrid_query_uses_rrf(self):
hits = self.eng.query("квартиры 3 этаж", limit=3)
self.assertGreater(len(hits), 0)
self.assertEqual(hits[0].file_path, "plan.md")
def test_get_returns_full_document(self):
hits = self.eng.search("авторизация")
full = self.eng.get(hits[0].doc_id)
self.assertIn("JWT", full)
self.assertIn("TTL", full)
def test_status_reports_engine(self):
status = self.eng.status()
self.assertEqual(status["files"], 3)
self.assertGreaterEqual(status["chunks"], 3)
self.assertEqual(status["embedding_dim"], 384)
self.assertTrue(status["embedding_loaded"])
if __name__ == "__main__":
unittest.main()