chartbastan/backend/tests/test_prediction_api.py
2026-02-01 09:31:38 +01:00

538 lines
19 KiB
Python

"""
Unit tests for Prediction API endpoints.
"""
import pytest
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from app.database import Base, engine, SessionLocal
from app.main import app
from app.models.match import Match
from app.models.prediction import Prediction
@pytest.fixture
def db_session():
    """Provide a SQLAlchemy session backed by a freshly created schema.

    All tables are created before the test body runs and dropped once it
    finishes, so every test using this fixture starts from an empty database.
    After the test, pending work is rolled back and the session is closed.
    """
    Base.metadata.create_all(bind=engine)
    test_session = SessionLocal()
    try:
        yield test_session
        # Discard anything the test left uncommitted.
        test_session.rollback()
    finally:
        test_session.close()
        # Drop the schema so the next test's create_all starts from scratch.
        Base.metadata.drop_all(bind=engine)
@pytest.fixture
def client(db_session: Session) -> TestClient:
    """Create a test client that runs against a freshly created schema.

    Requesting ``db_session`` here guarantees that ``Base.metadata.create_all``
    has run before any request is made and that the tables are dropped again
    afterwards, even for tests that ask only for ``client`` (the "not found",
    "empty" and validation cases). Without this dependency such tests would run
    after a previous test's teardown had already dropped every table, i.e.
    against a missing schema or whatever state was left behind.

    NOTE(review): no dependency override is installed, so the app presumably
    opens its own sessions on the same ``engine`` — confirm.
    """
    return TestClient(app)
@pytest.fixture
def sample_match(db_session: Session) -> Match:
    """Persist and return a scheduled Ligue 1 match: PSG vs Olympique de Marseille.

    The match date is the current UTC time. The instance is reloaded from the
    database after the commit so attributes such as ``id`` are populated.
    """
    fixture_match = Match(
        league="Ligue 1",
        home_team="PSG",
        away_team="Olympique de Marseille",
        status="scheduled",
        date=datetime.now(timezone.utc),
    )
    db_session.add(fixture_match)
    db_session.commit()
    db_session.refresh(fixture_match)
    return fixture_match
@pytest.fixture
def sample_prediction(db_session: Session, sample_match: Match) -> Prediction:
    """Persist and return a "high"-energy, PSG-win prediction for ``sample_match``.

    Field values are stored as literal strings (e.g. confidence ``"75.0%"``),
    so read endpoints can be compared against them verbatim.
    """
    stored = Prediction(
        match_id=sample_match.id,
        predicted_winner="PSG",
        energy_score="high",
        confidence="75.0%",
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(stored)
    db_session.commit()
    db_session.refresh(stored)
    return stored
class TestCreatePredictionEndpoint:
    """Tests for ``POST /api/v1/predictions/matches/{match_id}/predict``.

    Energies are sent as query parameters. The side with the higher energy is
    expected to be picked as winner, and a tie is reported as "Draw".
    """

    @staticmethod
    def _predict(client: TestClient, match_id, **params):
        """Send the predict request for ``match_id`` with ``params`` as query args."""
        return client.post(
            f"/api/v1/predictions/matches/{match_id}/predict",
            params=params,
        )

    def test_create_prediction_success(self, client: TestClient, sample_match: Match):
        """An explicit label and a home-energy lead produce a complete 201 payload."""
        resp = self._predict(
            client,
            sample_match.id,
            home_energy=65.0,
            away_energy=45.0,
            energy_score_label="high",
        )
        assert resp.status_code == 201
        body = resp.json()
        assert body["match_id"] == sample_match.id
        assert body["energy_score"] == "high"
        assert "%" in body["confidence"]
        assert body["predicted_winner"] == "PSG"
        # Server-generated fields must be present.
        for generated in ("id", "created_at"):
            assert generated in body

    def test_create_prediction_default_energy_label(self, client: TestClient, sample_match: Match):
        """Omitting ``energy_score_label`` still yields one of the known labels."""
        resp = self._predict(client, sample_match.id, home_energy=75.0, away_energy=65.0)
        assert resp.status_code == 201
        allowed_labels = ["high", "very_high", "medium", "low"]
        assert resp.json()["energy_score"] in allowed_labels

    def test_create_prediction_home_win(self, client: TestClient, sample_match: Match):
        """Higher home energy makes the home team the predicted winner."""
        resp = self._predict(client, sample_match.id, home_energy=70.0, away_energy=30.0)
        assert resp.status_code == 201
        assert resp.json()["predicted_winner"] == "PSG"

    def test_create_prediction_away_win(self, client: TestClient, sample_match: Match):
        """Higher away energy makes the away team the predicted winner."""
        resp = self._predict(client, sample_match.id, home_energy=30.0, away_energy=70.0)
        assert resp.status_code == 201
        assert resp.json()["predicted_winner"] == "Olympique de Marseille"

    def test_create_prediction_draw(self, client: TestClient, sample_match: Match):
        """Equal energies are reported as a draw with zero confidence."""
        resp = self._predict(client, sample_match.id, home_energy=50.0, away_energy=50.0)
        assert resp.status_code == 201
        body = resp.json()
        assert body["predicted_winner"] == "Draw"
        assert body["confidence"] == "0.0%"

    def test_create_prediction_nonexistent_match(self, client: TestClient):
        """An unknown match id is rejected with 400 and a "not found" detail."""
        resp = self._predict(client, 999, home_energy=65.0, away_energy=45.0)
        assert resp.status_code == 400
        assert "not found" in resp.json()["detail"].lower()

    def test_create_prediction_negative_energy(self, client: TestClient, sample_match: Match):
        """A negative energy value is rejected with 400."""
        resp = self._predict(client, sample_match.id, home_energy=-10.0, away_energy=45.0)
        assert resp.status_code == 400
class TestGetPredictionEndpoint:
    """Tests for ``GET /api/v1/predictions/{prediction_id}``."""

    def test_get_prediction_success(self, client: TestClient, sample_prediction: Prediction):
        """Fetching a stored prediction echoes back every persisted field."""
        resp = client.get(f"/api/v1/predictions/{sample_prediction.id}")
        assert resp.status_code == 200
        payload = resp.json()
        compared_fields = ("id", "match_id", "energy_score", "confidence", "predicted_winner")
        for field_name in compared_fields:
            assert payload[field_name] == getattr(sample_prediction, field_name)

    def test_get_prediction_not_found(self, client: TestClient):
        """An unknown id yields 404 with a "not found" detail message."""
        resp = client.get("/api/v1/predictions/999")
        assert resp.status_code == 404
        detail = resp.json()["detail"]
        assert "not found" in detail.lower()
class TestGetPredictionsForMatchEndpoint:
    """Tests for ``GET /api/v1/predictions/matches/{match_id}`` (all predictions of one match)."""

    def test_get_predictions_for_match_success(self, client: TestClient, db_session: Session, sample_match: Match):
        """Every prediction stored for the match is listed, with a matching count."""
        from app.services.prediction_service import PredictionService

        prediction_service = PredictionService(db_session)
        energy_pairs = [(50.0, 40.0), (60.0, 45.0), (70.0, 50.0)]
        for home, away in energy_pairs:
            prediction_service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=home,
                away_energy=away,
            )

        resp = client.get(f"/api/v1/predictions/matches/{sample_match.id}")
        assert resp.status_code == 200
        payload = resp.json()
        assert "data" in payload
        assert "count" in payload
        assert payload["count"] == len(energy_pairs)
        assert len(payload["data"]) == len(energy_pairs)
        assert all(item["match_id"] == sample_match.id for item in payload["data"])

    def test_get_predictions_for_empty_match(self, client: TestClient, sample_match: Match):
        """A match without predictions returns an empty list and a zero count."""
        resp = client.get(f"/api/v1/predictions/matches/{sample_match.id}")
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["count"] == 0
        assert not len(payload["data"])
class TestGetLatestPredictionForMatchEndpoint:
    """Test GET /api/v1/predictions/matches/{match_id}/latest endpoint."""

    def test_get_latest_prediction_success(self, client: TestClient, db_session: Session, sample_match: Match):
        """The newest of several predictions for a match is returned."""
        import time

        from app.services.prediction_service import PredictionService

        service = PredictionService(db_session)
        # Older prediction: it only has to exist so that "latest" has something
        # to be chosen over; its return value is not needed.
        service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=50.0,
            away_energy=40.0,
        )
        # Short pause so the second prediction gets a strictly later creation
        # timestamp (assumes "latest" is decided by created_at — confirm).
        time.sleep(0.01)
        newest = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=60.0,
            away_energy=35.0,
        )

        response = client.get(f"/api/v1/predictions/matches/{sample_match.id}/latest")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == newest.id
        assert data["match_id"] == sample_match.id

    def test_get_latest_prediction_not_found(self, client: TestClient, sample_match: Match):
        """Test getting latest prediction for match with no predictions."""
        response = client.get(f"/api/v1/predictions/matches/{sample_match.id}/latest")
        assert response.status_code == 404
        assert "No predictions found" in response.json()["detail"]
class TestDeletePredictionEndpoint:
    """Tests for ``DELETE /api/v1/predictions/{prediction_id}``."""

    def test_delete_prediction_success(self, client: TestClient, sample_prediction: Prediction):
        """Deleting returns 204, after which the prediction can no longer be fetched."""
        target_url = f"/api/v1/predictions/{sample_prediction.id}"
        assert client.delete(target_url).status_code == 204
        # A follow-up read of the same resource must now miss.
        assert client.get(target_url).status_code == 404

    def test_delete_prediction_not_found(self, client: TestClient):
        """Deleting an unknown id yields 404."""
        assert client.delete("/api/v1/predictions/999").status_code == 404
class TestGetPredictionsEndpoint:
    """Test GET /api/v1/predictions endpoint with pagination and filters."""

    def test_get_predictions_default(self, client: TestClient, db_session: Session, sample_match: Match):
        """Default listing returns every prediction plus v1 pagination metadata.

        With only five predictions stored, the default page (limit 20, offset 0)
        must contain all of them and ``meta.total`` must report five.
        """
        from app.services.prediction_service import PredictionService

        service = PredictionService(db_session)
        for i in range(5):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + i * 5,
                away_energy=40.0 + i * 3,
            )

        response = client.get("/api/v1/predictions")

        assert response.status_code == 200
        data = response.json()
        assert "data" in data
        assert "meta" in data
        meta = data["meta"]
        for key in ("total", "limit", "offset", "timestamp"):
            assert key in meta
        assert meta["version"] == "v1"
        assert meta["limit"] == 20
        assert meta["offset"] == 0
        # The five predictions created above must actually come back.
        assert meta["total"] == 5
        assert len(data["data"]) == 5

    def test_get_predictions_with_pagination(self, client: TestClient, db_session: Session, sample_match: Match):
        """Custom limit/offset split 25 predictions into full pages of 10."""
        from app.services.prediction_service import PredictionService

        service = PredictionService(db_session)
        for i in range(25):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + i,
                away_energy=40.0 + i,
            )

        # First page.
        response = client.get("/api/v1/predictions?limit=10&offset=0")
        assert response.status_code == 200
        data = response.json()
        assert len(data["data"]) == 10
        assert data["meta"]["limit"] == 10
        assert data["meta"]["offset"] == 0
        assert data["meta"]["total"] == 25

        # Second page: same limit and total, shifted offset.
        response2 = client.get("/api/v1/predictions?limit=10&offset=10")
        assert response2.status_code == 200
        data2 = response2.json()
        assert len(data2["data"]) == 10
        assert data2["meta"]["limit"] == 10
        assert data2["meta"]["offset"] == 10
        assert data2["meta"]["total"] == 25

    def test_get_predictions_with_league_filter(self, client: TestClient, db_session: Session):
        """Filtering by league returns only predictions for matches in that league."""
        from app.services.prediction_service import PredictionService

        service = PredictionService(db_session)
        match1 = Match(
            home_team="PSG",
            away_team="Olympique de Marseille",
            date=datetime.now(timezone.utc),
            league="Ligue 1",
            status="scheduled",
        )
        match2 = Match(
            home_team="Real Madrid",
            away_team="Barcelona",
            date=datetime.now(timezone.utc),
            league="La Liga",
            status="scheduled",
        )
        db_session.add_all([match1, match2])
        db_session.commit()
        db_session.refresh(match1)
        db_session.refresh(match2)

        service.create_prediction_for_match(match1.id, 65.0, 45.0)
        service.create_prediction_for_match(match2.id, 60.0, 50.0)

        # Let the client URL-encode the space in "Ligue 1".
        response = client.get("/api/v1/predictions", params={"league": "Ligue 1"})

        assert response.status_code == 200
        data = response.json()
        assert len(data["data"]) == 1
        assert data["data"][0]["match"]["league"] == "Ligue 1"

    def test_get_predictions_with_date_filter(self, client: TestClient, db_session: Session):
        """``date_min`` keeps predictions for matches on/after the cutoff and drops earlier ones.

        Both match dates sit well clear of the fixed cutoff, so the outcome does
        not depend on the day or time of day the suite runs.
        """
        from datetime import timedelta

        from app.services.prediction_service import PredictionService

        service = PredictionService(db_session)
        cutoff = "2025-01-01T00:00:00Z"
        # Always after the cutoff, whatever day the suite runs.
        future_date = datetime.now(timezone.utc) + timedelta(days=30)
        # Unambiguously before the cutoff (a 2025-01-01 date at the current time
        # of day would NOT be, because the cutoff is midnight).
        past_date = datetime(2024, 6, 1, 12, 0, tzinfo=timezone.utc)

        future_match = Match(
            home_team="PSG",
            away_team="Olympique de Marseille",
            date=future_date,
            league="Ligue 1",
            status="scheduled",
        )
        past_match = Match(
            home_team="Real Madrid",
            away_team="Barcelona",
            date=past_date,
            league="La Liga",
            status="scheduled",
        )
        db_session.add_all([future_match, past_match])
        db_session.commit()
        db_session.refresh(future_match)
        db_session.refresh(past_match)

        service.create_prediction_for_match(future_match.id, 65.0, 45.0)
        service.create_prediction_for_match(past_match.id, 60.0, 50.0)

        response = client.get("/api/v1/predictions", params={"date_min": cutoff})

        assert response.status_code == 200
        data = response.json()
        # Only the (Ligue 1) future match is on/after the cutoff.
        assert len(data["data"]) == 1
        returned = data["data"][0]
        assert returned["match"]["league"] == "Ligue 1"
        assert returned["match"]["date"] >= "2025-01-01"

    def test_get_predictions_invalid_limit(self, client: TestClient):
        """A limit above the allowed maximum is rejected by request validation."""
        response = client.get("/api/v1/predictions?limit=150")
        # Validated before the endpoint body executes.
        assert response.status_code == 422

    def test_get_predictions_empty(self, client: TestClient):
        """An empty database yields an empty page with a zero total."""
        response = client.get("/api/v1/predictions")
        assert response.status_code == 200
        data = response.json()
        assert len(data["data"]) == 0
        assert data["meta"]["total"] == 0
class TestGetPredictionByMatchIdEndpoint:
    """Test GET /api/v1/predictions/match/{match_id} endpoint."""

    def test_get_prediction_by_match_id_success(self, client: TestClient, db_session: Session, sample_match: Match):
        """The detail view embeds the match, the prediction fields and a history."""
        from app.services.prediction_service import PredictionService

        service = PredictionService(db_session)
        prediction = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=65.0,
            away_energy=45.0,
        )

        response = client.get(f"/api/v1/predictions/match/{sample_match.id}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == prediction.id
        assert data["match_id"] == sample_match.id
        assert "match" in data
        assert data["match"]["home_team"] == "PSG"
        assert data["match"]["away_team"] == "Olympique de Marseille"
        assert "energy_score" in data
        assert "confidence" in data
        assert "predicted_winner" in data
        assert "history" in data
        assert len(data["history"]) >= 1

    def test_get_prediction_by_match_id_with_history(self, client: TestClient, db_session: Session, sample_match: Match):
        """The latest prediction is returned; history lists all of them, newest first."""
        import time

        from app.services.prediction_service import PredictionService

        service = PredictionService(db_session)
        # Short pauses give each prediction a distinct, increasing creation time
        # (assumes ordering is by created_at — confirm).
        prediction1 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=50.0,
            away_energy=40.0,
        )
        time.sleep(0.01)
        prediction2 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=60.0,
            away_energy=35.0,
        )
        time.sleep(0.01)
        prediction3 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=70.0,
            away_energy=30.0,
        )

        response = client.get(f"/api/v1/predictions/match/{sample_match.id}")

        assert response.status_code == 200
        data = response.json()
        # Latest prediction is the top-level one.
        assert data["id"] == prediction3.id
        # History contains every prediction, newest first.
        assert len(data["history"]) == 3
        assert data["history"][0]["id"] == prediction3.id
        assert data["history"][1]["id"] == prediction2.id
        assert data["history"][2]["id"] == prediction1.id

    def test_get_prediction_by_match_id_not_found(self, client: TestClient):
        """A non-existent match id yields 404 with a "No predictions found" detail."""
        response = client.get("/api/v1/predictions/match/999")
        assert response.status_code == 404
        assert "No predictions found" in response.json()["detail"]

    def test_get_prediction_by_match_id_no_predictions(self, client: TestClient, sample_match: Match):
        """An existing match with no predictions yields the same 404 as an unknown one."""
        response = client.get(f"/api/v1/predictions/match/{sample_match.id}")
        assert response.status_code == 404
        assert "No predictions found" in response.json()["detail"]