cse2000-software-project/kernel/spectral/main.py
2024-06-10 17:17:37 +02:00

155 lines
4.3 KiB
Python

from typing import Annotated, Union, Literal
from fastapi import FastAPI, HTTPException, Path, Depends
from fastapi.responses import JSONResponse
from .mode_handler import (
simple_info_mode,
waveform_mode,
spectrogram_mode,
vowel_space_mode,
transcription_mode,
error_rate_mode,
convert_to_wav,
)
from .response_examples import (
signal_modes_response_examples,
transcription_response_examples,
)
from .transcription.transcription import get_transcription
from .data_objects import (
SimpleInfoResponse,
VowelSpaceResponse,
TranscriptionSegment,
ErrorRateResponse,
)
from .database import Database
import orjson
import json
import os
from typing import Any
def get_db(): # pragma: no cover
db = None
try:
user = os.getenv("POSTGRES_USER", "user")
password = os.getenv("POSTGRES_PASSWORD", "password")
host = os.getenv("POSTGRES_HOST", "localhost")
port = os.getenv("POSTGRES_PORT", "5432")
dbname = os.getenv("POSTGRES_DB", "postgres")
db = Database(user, password, host, port, dbname)
db.connection()
yield db
finally:
if db is not None:
db.close()
class ORJSONResponse(JSONResponse):
"""
Custom JSONResponse class using ORJSON to handle nan's.
"""
media_type = "application/json"
def render(self, content: Any) -> bytes:
return orjson.dumps(content)
app: FastAPI = FastAPI(default_response_class=ORJSONResponse, root_path="/api")
@app.get(
"/signals/modes/{mode}",
response_model=Union[
None,
SimpleInfoResponse,
VowelSpaceResponse,
list[list[TranscriptionSegment]],
ErrorRateResponse,
],
responses=signal_modes_response_examples,
)
async def analyze_signal_mode(
mode: Annotated[
Literal[
"simple-info",
"spectrogram",
"waveform",
"vowel-space",
"transcription",
"error-rate",
],
Path(title="The analysis mode"),
],
fileState,
database=Depends(get_db),
):
"""
Analyze an audio signal in different modes.
This endpoint fetches an audio file from the database and performs the analysis based on the specified mode.
Parameters:
- mode (str): The analysis mode (e.g., "simple-info", "spectrogram", "wave-form", "vowel-space", "transcription", "error-rate").
- fileState (dict): The important state data of the file
Returns:
- dict: The result of the analysis based on the selected mode.
Raises:
- HTTPException: If the mode is not found or input data is invalid.
"""
db_session = database
# db_session = next(database)
fileState = json.loads(fileState)
if mode == "simple-info":
return simple_info_mode(db_session, fileState)
if mode == "spectrogram":
return spectrogram_mode(db_session, fileState)
if mode == "waveform":
return waveform_mode(db_session, fileState)
if mode == "vowel-space":
return vowel_space_mode(db_session, fileState)
if mode == "transcription":
return transcription_mode(db_session, fileState)
if mode == "error-rate":
return error_rate_mode(db_session, fileState)
@app.get(
"/transcription/{model}/{file_id}",
response_model=list[TranscriptionSegment],
responses=transcription_response_examples,
)
async def transcribe_file(
model: Annotated[str, Path(title="The transcription model")],
file_id: Annotated[str, Path(title="The ID of the file")],
database=Depends(get_db),
):
"""
Transcribe an audio file.
This endpoint transcribes an audio file using the specified model.
Parameters:
- model (str): The transcription model to use.
- file_id (str): The ID of the file to transcribe.
Returns:
- list: A list of dictionaries with keys 'start', 'end' and 'value' containing the transcription of the audio file.
Raises:
- HTTPException: If the file is not found or an error occurs during transcription or storing the transcription.
"""
db_session = database
# db_session = next(database)
try:
file = db_session.fetch_file(file_id)
except Exception as _:
raise HTTPException(status_code=404, detail="File not found")
file["data"] = convert_to_wav(file["data"])
transcription = get_transcription(model, file)
return transcription