Загружает файлы, супер

This commit is contained in:
maksim 2024-12-17 01:01:44 +04:00
parent 93e2ffe021
commit 482daf29da
4 changed files with 128 additions and 67 deletions

View File

@ -1,30 +1,57 @@
from typing import Optional from datetime import datetime
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy import ForeignKey
engine = create_async_engine("sqlite+aiosqlite:///tasks.db") engine = create_async_engine("sqlite+aiosqlite:///tasks.db")
new_session = async_sessionmaker(engine, expire_on_commit=False) new_session = async_sessionmaker(engine, expire_on_commit=False)
class Base(DeclarativeBase):
class Model(DeclarativeBase):
pass pass
class User(Base):
class TaskOrm(Model): __tablename__ = "users"
__tablename__ = "tasks"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] username: Mapped[str] = mapped_column(unique=True)
description: Mapped[Optional[str]] password_hash: Mapped[str]
# Связи
csv_files = relationship("CSVFile", back_populates="user")
models = relationship("H5Model", back_populates="user")
class CSVFile(Base):
__tablename__ = "csv_files"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
file_path: Mapped[str]
uploaded_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
user = relationship("User", back_populates="csv_files")
class H5Model(Base):
__tablename__ = "h5_models"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
model_path: Mapped[str]
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
user = relationship("User", back_populates="models")
statistics = relationship("ModelStatistics", back_populates="model")
class ModelStatistics(Base):
__tablename__ = "model_statistics"
id: Mapped[int] = mapped_column(primary_key=True)
model_id: Mapped[int] = mapped_column(ForeignKey("h5_models.id"))
accuracy: Mapped[float]
loss: Mapped[float]
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
model = relationship("H5Model", back_populates="statistics")
async def create_tables(): async def create_tables():
# https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#synopsis-core
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Model.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
async def delete_tables(): async def delete_tables():
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Model.metadata.drop_all) await conn.run_sync(Base.metadata.drop_all)

View File

@ -1,26 +1,36 @@
from database import new_session, User, CSVFile, H5Model, ModelStatistics
from schemas import UserCreate, CSVFileUpload, H5ModelCreate, ModelStatisticsCreate
from sqlalchemy import select from sqlalchemy import select
from database import new_session, TaskOrm class UserRepository:
from schemas import STaskAdd, STask @staticmethod
async def create_user(data: UserCreate):
class TaskRepository:
@classmethod
async def add_one(cls, data: STaskAdd) -> int:
async with new_session() as session: async with new_session() as session:
task_dict = data.model_dump() user = User(username=data.username, password_hash=data.password)
session.add(user)
task = TaskOrm(**task_dict)
session.add(task)
await session.flush()
await session.commit() await session.commit()
return task.id return user.id
@classmethod class CSVFileRepository:
async def find_all(cls) -> list[STask]: @staticmethod
async def upload_file(user_id: int, file_path: str):
async with new_session() as session: async with new_session() as session:
query = select(TaskOrm) csv_file = CSVFile(user_id=user_id, file_path=file_path)
result = await session.execute(query) session.add(csv_file)
task_models = result.scalars().all() await session.commit()
task_schemas = [STask.model_validate(task_model) for task_model in task_models]
return task_schemas class H5ModelRepository:
@staticmethod
async def add_model(user_id: int, model_path: str):
async with new_session() as session:
model = H5Model(user_id=user_id, model_path=model_path)
session.add(model)
await session.commit()
class ModelStatisticsRepository:
@staticmethod
async def add_statistics(data: ModelStatisticsCreate):
async with new_session() as session:
stats = ModelStatistics(**data.model_dump())
session.add(stats)
await session.commit()

View File

@ -1,25 +1,45 @@
from typing import Annotated from fastapi import APIRouter, UploadFile, File, Form
from repository import UserRepository, CSVFileRepository, H5ModelRepository, ModelStatisticsRepository
from schemas import UserCreate, ModelStatisticsCreate
import shutil
import os
from fastapi import APIRouter, Depends router = APIRouter()
from repository import TaskRepository UPLOAD_FOLDER_CSV = "uploads/csv"
from schemas import STaskAdd, STask, STaskId os.makedirs(UPLOAD_FOLDER_CSV, exist_ok=True)
router = APIRouter( UPLOAD_FOLDER_MODELS = "uploads/models"
prefix="/tasks", os.makedirs(UPLOAD_FOLDER_MODELS, exist_ok=True)
tags=["Таски"],
)
# Регистрация пользователя
@router.post("/users/")
async def create_user(user: UserCreate):
user_id = await UserRepository.create_user(user)
return {"user_id": user_id}
@router.post("") # Загрузка CSV файла
async def add_task( @router.post("/upload/csv/")
task: Annotated[STaskAdd, Depends()], async def upload_csv(user_id: int = Form(...), file: UploadFile = File(...)):
) -> STaskId: file_path = os.path.join(UPLOAD_FOLDER_CSV, file.filename)
task_id = await TaskRepository.add_one(task) with open(file_path, "wb") as buffer:
return {"ok": True, "task_id": task_id} shutil.copyfileobj(file.file, buffer)
await CSVFileRepository.upload_file(user_id=user_id, file_path=file_path)
return {"message": "CSV файл загружен", "file_path": file_path}
@router.get("") # Загрузка H5 модели
async def get_tasks() -> list[STask]: @router.post("/upload/h5/")
tasks = await TaskRepository.find_all() async def upload_h5_model(user_id: int = Form(...), file: UploadFile = File(...)):
return tasks file_path = os.path.join(UPLOAD_FOLDER_MODELS, file.filename)
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
await H5ModelRepository.add_model(user_id=user_id, model_path=file_path)
return {"message": "H5 модель загружена", "file_path": file_path}
# Добавление статистики модели
@router.post("/models/statistics/")
async def add_model_statistics(stats: ModelStatisticsCreate):
await ModelStatisticsRepository.add_statistics(stats)
return {"message": "Статистика модели сохранена"}

View File

@ -1,19 +1,23 @@
from pydantic import BaseModel
from datetime import datetime
from typing import Optional from typing import Optional
from pydantic import BaseModel, ConfigDict # UserCreate: для создания пользователя
class UserCreate(BaseModel):
username: str
password: str
# CSVFileUpload: для загрузки CSV файла
class CSVFileUpload(BaseModel):
file_path: str
class STaskAdd(BaseModel): # H5ModelCreate: для добавления модели
name: str class H5ModelCreate(BaseModel):
description: Optional[str] = None model_path: str
# ModelStatisticsCreate: для сохранения статистики модели
class STask(STaskAdd): class ModelStatisticsCreate(BaseModel):
id: int model_id: int
accuracy: float
model_config = ConfigDict(from_attributes=True) loss: float
created_at: Optional[datetime] = None
class STaskId(BaseModel):
ok: bool = True
task_id: int