Skip to content

Commit

Permalink
Apply resnet50 trained model
Browse files Browse the repository at this point in the history
  • Loading branch information
KarinaTiurina committed Jan 25, 2024
1 parent c8979e4 commit fe43d98
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions backend/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
import io
import base64
from core.models.resnet18_7c import Resnet18_7C
from core.models.resnet_7c import Resnet_7C

app = FastAPI()

Expand All @@ -27,7 +27,7 @@
allow_headers=["*"],
)

Resnet18_7CModel = Resnet18_7C() # init trained model
Resnet_7CModel = Resnet_7C() # init trained model


@app.get("/")
Expand Down Expand Up @@ -57,7 +57,7 @@ async def detect_age_single(websocket: WebSocket):
data = await websocket.receive_text()
decoded_data = base64.b64decode(data)
frame = cv2.imdecode(np.frombuffer(decoded_data, dtype=np.uint8), 1)
detect_faces(frame, Resnet18_7CModel)
detect_faces(frame, Resnet_7CModel)
_, encoded_frame = cv2.imencode('.jpg', frame)
image = base64.b64encode(encoded_frame.tobytes()).decode('utf-8')
await websocket.send_text(image)
Expand All @@ -73,7 +73,7 @@ def detect_age_multiple(files: List[UploadFile] = File(...)):
content = file.file.read()
image_array = cv2.imdecode(np.frombuffer(content, np.uint8), -1)

detected_faces = detect_faces(image_array, Resnet18_7CModel)
detected_faces = detect_faces(image_array, Resnet_7CModel)
# if not detected_faces['faces'].any():
# raise HTTPException(status_code=400, detail="No faces detected in the image.")

Expand Down Expand Up @@ -107,7 +107,7 @@ async def detect_age_video(file: UploadFile = File(...)):
with open(file_path, "wb") as video_file:
shutil.copyfileobj(file.file, video_file)

frames, frame_rate = process_video(file_path, detect_faces, Resnet18_7CModel)
frames, frame_rate = process_video(file_path, detect_faces, Resnet_7CModel)

video_bytes = generate_video(frames, temp_dir, frame_rate)

Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from torchvision import transforms
from torchvision.models import ResNet50_Weights

class Resnet18_7C:
def __init__(self, model_path = 'assets/models/resnet18_7c_aug_lr01_25e_step5.pt'):
class Resnet_7C:
def __init__(self, model_path = 'assets/models/resnet50-aug-dropout-lr01.pt'):
self.classes = {
0: '0 - 2',
1: '3 - 9',
Expand Down

0 comments on commit fe43d98

Please sign in to comment.