155 lines
4.3 KiB
Python
155 lines
4.3 KiB
Python
from typing import Annotated, Union, Literal
|
|
from fastapi import FastAPI, HTTPException, Path, Depends
|
|
from fastapi.responses import JSONResponse
|
|
from .mode_handler import (
|
|
simple_info_mode,
|
|
waveform_mode,
|
|
spectrogram_mode,
|
|
vowel_space_mode,
|
|
transcription_mode,
|
|
error_rate_mode,
|
|
convert_to_wav,
|
|
)
|
|
from .response_examples import (
|
|
signal_modes_response_examples,
|
|
transcription_response_examples,
|
|
)
|
|
from .transcription.transcription import get_transcription
|
|
from .data_objects import (
|
|
SimpleInfoResponse,
|
|
VowelSpaceResponse,
|
|
TranscriptionSegment,
|
|
ErrorRateResponse,
|
|
)
|
|
from .database import Database
|
|
import orjson
|
|
import json
|
|
import os
|
|
from typing import Any
|
|
|
|
|
|
def get_db():  # pragma: no cover
    """
    Dependency generator that yields a Database handle for one request.

    The connection settings are read from the environment, each with a
    fallback value, and passed positionally to ``Database`` in the order
    user, password, host, port, database name. ``connection()`` is called on
    the new object before it is yielded.

    Yields:
    - Database: the object built from the environment settings.

    The object's ``close()`` method is called when the generator finishes,
    including when an error occurs, provided the object was constructed.
    """
    # (environment variable, fallback) pairs, in Database's positional order.
    settings = (
        ("POSTGRES_USER", "user"),
        ("POSTGRES_PASSWORD", "password"),
        ("POSTGRES_HOST", "localhost"),
        ("POSTGRES_PORT", "5432"),
        ("POSTGRES_DB", "postgres"),
    )
    handle = None
    try:
        handle = Database(*(os.getenv(name, fallback) for name, fallback in settings))
        handle.connection()
        yield handle
    finally:
        # Nothing to release if construction itself failed.
        if handle is not None:
            handle.close()
|
|
|
|
|
|
class ORJSONResponse(JSONResponse):
    """
    JSONResponse subclass whose body is encoded with orjson.

    It exists so that response payloads containing NaN values can be handled.
    """

    media_type = "application/json"

    def render(self, content: Any) -> bytes:
        """Encode ``content`` to JSON bytes using ``orjson.dumps``."""
        encoded = orjson.dumps(content)
        return encoded
|
|
|
|
|
|
# Application instance: responses are rendered with ORJSONResponse unless a
# route overrides it, and the app is configured with the "/api" root path.
app: FastAPI = FastAPI(
    root_path="/api",
    default_response_class=ORJSONResponse,
)
|
|
|
|
|
|
@app.get(
    "/signals/modes/{mode}",
    response_model=Union[
        None,
        SimpleInfoResponse,
        VowelSpaceResponse,
        list[list[TranscriptionSegment]],
        ErrorRateResponse,
    ],
    responses=signal_modes_response_examples,
)
async def analyze_signal_mode(
    mode: Annotated[
        Literal[
            "simple-info",
            "spectrogram",
            "waveform",
            "vowel-space",
            "transcription",
            "error-rate",
        ],
        Path(title="The analysis mode"),
    ],
    fileState: str,
    database=Depends(get_db),
):
    """
    Analyze an audio signal in different modes.

    The JSON-encoded ``fileState`` is decoded and handed, together with the
    database dependency, to the handler that belongs to the requested mode.

    Parameters:
    - mode (str): The analysis mode; one of "simple-info", "spectrogram",
      "waveform", "vowel-space", "transcription" or "error-rate".
    - fileState (str): JSON-encoded state data of the file. It is decoded
      with ``json.loads`` before being passed to the mode handler.
    - database: The database object provided by the ``get_db`` dependency.

    Returns:
    - Whatever the selected mode handler returns. When the function is called
      directly with a mode that has no handler, None is returned.

    Raises:
    - HTTPException (400): If ``fileState`` cannot be decoded as JSON.
    """
    # Decode first, as before, so malformed input is rejected for every mode.
    # Without this guard a bad query value would surface as an unhandled 500.
    try:
        file_state = json.loads(fileState)
    except (json.JSONDecodeError, TypeError) as err:
        raise HTTPException(
            status_code=400, detail="fileState must be a valid JSON string"
        ) from err

    # Built per call so the handler names are resolved at request time.
    handlers = {
        "simple-info": simple_info_mode,
        "spectrogram": spectrogram_mode,
        "waveform": waveform_mode,
        "vowel-space": vowel_space_mode,
        "transcription": transcription_mode,
        "error-rate": error_rate_mode,
    }
    handler = handlers.get(mode)
    if handler is None:
        # Preserve the previous fall-through result for an unrecognised mode.
        return None
    return handler(database, file_state)
|
|
|
|
|
|
@app.get(
    "/transcription/{model}/{file_id}",
    response_model=list[TranscriptionSegment],
    responses=transcription_response_examples,
)
async def transcribe_file(
    model: Annotated[str, Path(title="The transcription model")],
    file_id: Annotated[str, Path(title="The ID of the file")],
    database=Depends(get_db),
):
    """
    Transcribe an audio file.

    The file is fetched from the database by its ID, its "data" entry is
    replaced by the output of ``convert_to_wav``, and the file is then passed
    to ``get_transcription`` together with the requested model.

    Parameters:
    - model (str): The transcription model to use.
    - file_id (str): The ID of the file to transcribe.
    - database: The database object provided by the ``get_db`` dependency.

    Returns:
    - The value returned by ``get_transcription``; the route declares it as a
      list of TranscriptionSegment.

    Raises:
    - HTTPException (404): If fetching the file from the database raises any
      exception.
    """
    try:
        audio_file = database.fetch_file(file_id)
    except Exception:
        # Any failure while fetching is reported to the client as "not found".
        raise HTTPException(status_code=404, detail="File not found")

    wav_bytes = convert_to_wav(audio_file["data"])
    audio_file["data"] = wav_bytes

    return get_transcription(model, audio_file)
|