193 lines
7.3 KiB
Python
193 lines
7.3 KiB
Python
|
|
# app/crud.py
|
||
|
|
from sqlalchemy.orm import Session, joinedload
|
||
|
|
from sqlalchemy import func
|
||
|
|
from typing import List, Optional
|
||
|
|
import app.models as models
|
||
|
|
import app.schemas as schemas
|
||
|
|
|
||
|
|
|
||
|
|
# ---------- Project ----------
|
||
|
|
def create_project(db: Session, pdf_filename: str, name: Optional[str] = None) -> models.Project:
    """Insert a new project row for an uploaded PDF and return it.

    Args:
        db: Active SQLAlchemy session.
        pdf_filename: File name of the uploaded PDF.
        name: Display name; when omitted or empty, ``pdf_filename`` is used.

    Returns:
        The persisted project, refreshed from the database after the commit.
    """
    display_name = name if name else pdf_filename
    new_project = models.Project(
        name=display_name,
        pdf_filename=pdf_filename,
        status="uploaded",
    )
    db.add(new_project)
    db.commit()
    # Reload the instance's state from the database after the commit.
    db.refresh(new_project)
    return new_project
|
||
|
|
|
||
|
|
def get_project(db: Session, project_id: int) -> Optional[models.Project]:
    """Fetch one project by id with pages, issues and issue feedback eager-loaded.

    Args:
        db: Active SQLAlchemy session.
        project_id: Primary key of the project to look up.

    Returns:
        The matching project, or ``None`` when no row has that id.
    """
    eager_options = (
        joinedload(models.Project.pages),
        joinedload(models.Project.issues).joinedload(models.Issue.feedback),
    )
    lookup = db.query(models.Project).options(*eager_options)
    lookup = lookup.filter(models.Project.id == project_id)
    return lookup.first()
|
||
|
|
|
||
|
|
def get_projects(db: Session, skip: int = 0, limit: int = 100) -> List[models.Project]:
    """Return one page of projects, newest first, with related rows eager-loaded.

    Args:
        db: Active SQLAlchemy session.
        skip: Number of projects to skip from the start of the ordering.
        limit: Maximum number of projects to return.

    Returns:
        Projects ordered by ``created_at`` descending, each with its pages,
        issues and issue feedback loaded via joined eager loading.
    """
    eager_options = (
        joinedload(models.Project.pages),
        joinedload(models.Project.issues).joinedload(models.Issue.feedback),
    )
    listing = db.query(models.Project).options(*eager_options)
    newest_first = listing.order_by(models.Project.created_at.desc())
    window = newest_first.offset(skip).limit(limit)
    return window.all()
|
||
|
|
|
||
|
|
def update_project_status(
    db: Session,
    project_id: int,
    status: str,
    error_message: Optional[str] = None,
    output_folder: Optional[str] = None,
) -> Optional[models.Project]:
    """Update a project's status and the fields that depend on it.

    Args:
        db: Active SQLAlchemy session.
        project_id: Primary key of the project to update.
        status: New status value to store.
        error_message: When given, stored on the project. When omitted and
            ``status`` is ``"processing"`` or ``"completed"``, any previously
            stored error is cleared.
        output_folder: When non-empty, stored as the project's output folder.

    Returns:
        The refreshed project, or ``None`` when no project has that id.
    """
    project = get_project(db, project_id)
    if not project:
        return None

    project.status = status

    if error_message is not None:
        project.error_message = error_message
    elif status in ("processing", "completed"):
        # Clear a stale error on a new run or on success.
        project.error_message = None

    if output_folder:
        project.output_folder = output_folder

    if status == "completed":
        from datetime import datetime, timezone

        # datetime.utcnow() is deprecated since Python 3.12. Build the same
        # naive UTC timestamp from an aware "now" so the stored value keeps
        # its previous shape.
        project.completed_at = datetime.now(timezone.utc).replace(tzinfo=None)

    db.commit()
    db.refresh(project)
    return project
|
||
|
|
|
||
|
|
|
||
|
|
# ---------- Page ----------
|
||
|
|
def create_page(db: Session, project_id: int, page_number: int, **kwargs) -> models.Page:
    """Insert a page row belonging to a project and return it.

    Args:
        db: Active SQLAlchemy session.
        project_id: Id of the owning project.
        page_number: Number of the page within the project.
        **kwargs: Additional ``models.Page`` column values, passed through
            to the model constructor unchanged.

    Returns:
        The persisted page, refreshed from the database after the commit.
    """
    page_fields = dict(kwargs, project_id=project_id, page_number=page_number)
    new_page = models.Page(**page_fields)
    db.add(new_page)
    db.commit()
    db.refresh(new_page)
    return new_page
|
||
|
|
|
||
|
|
def get_page_by_number(db: Session, project_id: int, page_number: int) -> Optional[models.Page]:
    """Look up a page of a project by its page number.

    Args:
        db: Active SQLAlchemy session.
        project_id: Id of the project the page belongs to.
        page_number: Page number to match.

    Returns:
        The first matching page, or ``None`` when there is none.
    """
    in_project = db.query(models.Page).filter(models.Page.project_id == project_id)
    with_number = in_project.filter(models.Page.page_number == page_number)
    return with_number.first()
|
||
|
|
|
||
|
|
|
||
|
|
# ---------- Issue ----------
|
||
|
|
def create_issue(db: Session, project_id: int, page_id: Optional[int], **kwargs) -> models.Issue:
    """Insert an issue row for a project (optionally tied to a page) and return it.

    Args:
        db: Active SQLAlchemy session.
        project_id: Id of the owning project.
        page_id: Id of the page the issue belongs to, or ``None``.
        **kwargs: Additional ``models.Issue`` column values, passed through
            to the model constructor unchanged.

    Returns:
        The persisted issue, refreshed from the database after the commit.
    """
    issue_fields = dict(kwargs, project_id=project_id, page_id=page_id)
    new_issue = models.Issue(**issue_fields)
    db.add(new_issue)
    db.commit()
    db.refresh(new_issue)
    return new_issue
|
||
|
|
|
||
|
|
def get_issues(db: Session, project_id: Optional[int] = None, page_id: Optional[int] = None,
               severity: Optional[str] = None, issue_type: Optional[str] = None,
               has_feedback: Optional[bool] = None, skip: int = 0, limit: int = 1000):
    """Return one page of issues, newest first, narrowed by the given filters.

    Args:
        db: Active SQLAlchemy session.
        project_id: Keep only issues of this project (ignored when falsy).
        page_id: Keep only issues of this page (ignored when falsy).
        severity: Keep only issues with this severity (ignored when falsy).
        issue_type: Keep only issues of this type (ignored when falsy).
        has_feedback: ``True`` keeps issues that have feedback, ``False``
            keeps issues without feedback, ``None`` applies no feedback filter.
        skip: Number of issues to skip from the start of the ordering.
        limit: Maximum number of issues to return.

    Returns:
        A list of matching issues ordered by ``created_at`` descending.
    """
    # Each (column, wanted value) pair becomes an equality filter only when
    # the wanted value is truthy.
    equality_filters = (
        (models.Issue.project_id, project_id),
        (models.Issue.page_id, page_id),
        (models.Issue.severity, severity),
        (models.Issue.issue_type, issue_type),
    )

    selection = db.query(models.Issue)
    for column, wanted in equality_filters:
        if wanted:
            selection = selection.filter(column == wanted)

    if has_feedback is not None:
        # Comparing the relationship against None is intentional: SQLAlchemy
        # turns it into SQL, so ``is`` / ``is not`` must not be used here.
        if has_feedback:
            feedback_clause = models.Issue.feedback != None  # noqa: E711
        else:
            feedback_clause = models.Issue.feedback == None  # noqa: E711
        selection = selection.filter(feedback_clause)

    newest_first = selection.order_by(models.Issue.created_at.desc())
    return newest_first.offset(skip).limit(limit).all()
|
||
|
|
|
||
|
|
def get_issue(db: Session, issue_id: int) -> Optional[models.Issue]:
    """Fetch a single issue by id.

    Args:
        db: Active SQLAlchemy session.
        issue_id: Primary key of the issue to look up.

    Returns:
        The matching issue, or ``None`` when no row has that id.
    """
    by_id = models.Issue.id == issue_id
    lookup = db.query(models.Issue).filter(by_id)
    return lookup.first()
|
||
|
|
|
||
|
|
|
||
|
|
# ---------- Feedback ----------
|
||
|
|
def create_feedback(db: Session, feedback: schemas.FeedbackCreate) -> models.Feedback:
    """Store feedback for an issue, replacing any feedback it already has.

    The removal of the old row and the insertion of the new one happen in a
    single transaction, so a failed insert no longer discards the previous
    feedback.

    Args:
        db: Active SQLAlchemy session.
        feedback: Validated payload; its fields are passed to
            ``models.Feedback`` and its ``issue_id`` identifies the row to
            replace.

    Returns:
        The newly persisted feedback, refreshed from the database.

    Raises:
        Exception: Any error raised while writing is re-raised after the
            session has been rolled back.
    """
    # Pydantic v2 renamed ``dict()`` to ``model_dump()``; accept either.
    dump = getattr(feedback, "model_dump", None) or feedback.dict
    payload = dump()

    try:
        existing = (
            db.query(models.Feedback)
            .filter(models.Feedback.issue_id == feedback.issue_id)
            .first()
        )
        if existing:
            # Remove the old feedback, if there is one.
            db.delete(existing)
            # Emit the DELETE before the INSERT below so that the old and
            # new rows for the same issue never coexist within the flush.
            db.flush()

        db_feedback = models.Feedback(**payload)
        db.add(db_feedback)
        db.commit()
    except Exception:
        # Undo the pending delete as well, then let the caller handle it.
        db.rollback()
        raise

    db.refresh(db_feedback)
    return db_feedback
|
||
|
|
|
||
|
|
def get_feedback_stats(db: Session, project_id: Optional[int] = None):
    """Count feedback rows by verdict and estimate detection accuracy.

    Args:
        db: Active SQLAlchemy session.
        project_id: When truthy, only feedback on issues of this project is
            counted; otherwise all feedback is counted.

    Returns:
        A dict with keys ``total``, ``true_positive``, ``false_positive``,
        ``unreviewed`` (feedback whose verdict is NULL) and
        ``accuracy_estimate``: the share of true positives among feedback
        with a verdict, rounded to 3 decimals, or ``None`` when no feedback
        has a verdict yet.
    """
    query = db.query(models.Feedback)
    if project_id:
        query = query.join(models.Issue).filter(models.Issue.project_id == project_id)

    verdict = models.Feedback.is_true_positive
    total = query.count()
    # ``==`` (not ``is``) is required: these build SQL expressions.
    true_positive = query.filter(verdict == True).count()  # noqa: E712
    false_positive = query.filter(verdict == False).count()  # noqa: E712
    unreviewed = query.filter(verdict == None).count()  # noqa: E711

    reviewed = true_positive + false_positive
    accuracy = true_positive / reviewed if reviewed > 0 else None

    return {
        "total": total,
        "true_positive": true_positive,
        "false_positive": false_positive,
        "unreviewed": unreviewed,
        # Test against None explicitly: an accuracy of 0.0 (every reviewed
        # item was a false positive) is a real value, not "no data".
        "accuracy_estimate": round(accuracy, 3) if accuracy is not None else None,
    }
|
||
|
|
|
||
|
|
|
||
|
|
# ---------- Stats ----------
|
||
|
|
def get_stats(db: Session) -> schemas.StatsResponse:
    """Assemble global statistics over all projects, issues and feedback.

    Args:
        db: Active SQLAlchemy session.

    Returns:
        A ``StatsResponse`` holding the project and issue totals, the issue
        count per issue type, the feedback verdict counts and the accuracy
        estimate produced by ``get_feedback_stats``.
    """
    project_count = db.query(models.Project).count()
    issue_count = db.query(models.Issue).count()

    # Per-type breakdown
    per_type_rows = (
        db.query(models.Issue.issue_type, func.count(models.Issue.id))
        .group_by(models.Issue.issue_type)
        .all()
    )
    issues_by_type = {kind: count for kind, count in per_type_rows}

    # Feedback (across all projects)
    fb_stats = get_feedback_stats(db)
    verdict_counts = {
        key: fb_stats[key]
        for key in ("true_positive", "false_positive", "unreviewed")
    }

    return schemas.StatsResponse(
        total_projects=project_count,
        total_issues=issue_count,
        issues_by_type=issues_by_type,
        feedback_stats=verdict_counts,
        accuracy_estimate=fb_stats["accuracy_estimate"],
    )
|
||
|
|
|
||
|
|
|
||
|
|
# ---------- Training Data ----------
|
||
|
|
def export_training_data(db: Session, project_id: Optional[int] = None,
                         only_labeled: bool = True) -> List[schemas.TrainingSample]:
    """Export issues as training samples for an ML model.

    Args:
        db: Active SQLAlchemy session.
        project_id: When truthy, only issues of this project are exported.
        only_labeled: When true, only issues that have a feedback row are
            exported (inner join on feedback).

    Returns:
        One ``TrainingSample`` per exported issue. ``label`` is
        ``"good_<type>"`` / ``"bad_<type>"`` (type lower-cased) when the
        feedback carries a verdict, and ``None`` when the issue has no
        feedback or the feedback's verdict is still unset.
    """
    # Load page and feedback together with the issues so the loop below does
    # not trigger a lazy load per issue.
    query = db.query(models.Issue).options(
        joinedload(models.Issue.page),
        joinedload(models.Issue.feedback),
    )
    if project_id:
        query = query.filter(models.Issue.project_id == project_id)
    if only_labeled:
        query = query.join(models.Feedback)

    samples = []
    for issue in query.all():
        page = issue.page
        feedback = issue.feedback
        verdict = feedback.is_true_positive if feedback else None

        # A verdict of None means "not reviewed yet"; it must not be
        # exported as a negative ("bad") label.
        if verdict is None:
            label = None
        else:
            label = f"{'good' if verdict else 'bad'}_{issue.issue_type.lower()}"

        samples.append(
            schemas.TrainingSample(
                issue_id=issue.id,
                issue_type=issue.issue_type,
                severity=issue.severity,
                message=issue.message,
                bbox={
                    "x1": issue.bbox_x1,
                    "y1": issue.bbox_y1,
                    "x2": issue.bbox_x2,
                    "y2": issue.bbox_y2,
                },
                dimension_text=issue.dimension_text,
                confidence=issue.confidence,
                page_number=page.page_number if page else None,
                is_true_positive=verdict,
                image_path=page.png_path if page else None,
                label=label,
            )
        )

    return samples
|