"""
Unit tests for Prediction API endpoints.
"""
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from datetime import datetime, timezone
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.database import Base, engine, SessionLocal
|
|
from app.main import app
|
|
from app.models.match import Match
|
|
from app.models.prediction import Prediction
|
|
|
|
|
|
@pytest.fixture(scope="function")
def db_session():
    """Yield a database session backed by a schema built just for this test.

    All tables are created before the session is opened. After the test the
    session is rolled back and closed, and every table is dropped again.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    db = SessionLocal()
    try:
        yield db
        # Discard anything the test left uncommitted.
        db.rollback()
    finally:
        db.close()
        metadata.drop_all(bind=engine)
|
|
|
|
|
|
@pytest.fixture
def client(db_session: Session) -> TestClient:
    """Create a test client for the application.

    ``db_session`` is requested only for its setup/teardown: it creates the
    tables before the test and drops them afterwards. Without it, tests that
    ask for ``client`` alone (the not-found, empty-list and validation tests)
    would run against whatever schema a previous test happened to leave
    behind, making them order-dependent.

    Args:
        db_session: Per-test database session fixture; unused directly.

    Returns:
        A ``TestClient`` wrapping the FastAPI ``app``.
    """
    return TestClient(app)
|
|
|
|
|
|
@pytest.fixture
def sample_match(db_session: Session) -> Match:
    """Persist a scheduled Ligue 1 match (PSG vs. Olympique de Marseille).

    The match is committed through ``db_session`` and refreshed before being
    returned to the test.
    """
    match_fields = {
        "home_team": "PSG",
        "away_team": "Olympique de Marseille",
        "date": datetime.now(timezone.utc),
        "league": "Ligue 1",
        "status": "scheduled",
    }
    stored_match = Match(**match_fields)

    db_session.add(stored_match)
    db_session.commit()
    db_session.refresh(stored_match)
    return stored_match
|
|
|
|
|
|
@pytest.fixture
def sample_prediction(db_session: Session, sample_match: Match) -> Prediction:
    """Persist one prediction attached to ``sample_match``.

    The prediction is committed through ``db_session`` and refreshed before
    being returned to the test.
    """
    prediction_fields = {
        "match_id": sample_match.id,
        "energy_score": "high",
        "confidence": "75.0%",
        "predicted_winner": "PSG",
        "created_at": datetime.now(timezone.utc),
    }
    stored_prediction = Prediction(**prediction_fields)

    db_session.add(stored_prediction)
    db_session.commit()
    db_session.refresh(stored_prediction)
    return stored_prediction
|
|
|
|
|
|
class TestCreatePredictionEndpoint:
    """Exercise POST /api/v1/predictions/matches/{match_id}/predict."""

    def test_create_prediction_success(self, client: TestClient, sample_match: Match):
        """An explicit energy label is accepted and echoed back with a 201."""
        url = f"/api/v1/predictions/matches/{sample_match.id}/predict"
        query = {
            "home_energy": 65.0,
            "away_energy": 45.0,
            "energy_score_label": "high",
        }

        resp = client.post(url, params=query)

        assert resp.status_code == 201
        body = resp.json()

        assert body["match_id"] == sample_match.id
        assert body["energy_score"] == "high"
        assert "%" in body["confidence"]
        assert body["predicted_winner"] == "PSG"
        assert "id" in body
        assert "created_at" in body

    def test_create_prediction_default_energy_label(self, client: TestClient, sample_match: Match):
        """Omitting the label still yields one of the known energy labels."""
        url = f"/api/v1/predictions/matches/{sample_match.id}/predict"
        query = {"home_energy": 75.0, "away_energy": 65.0}

        resp = client.post(url, params=query)

        assert resp.status_code == 201
        body = resp.json()

        known_labels = ("high", "very_high", "medium", "low")
        assert body["energy_score"] in known_labels

    def test_create_prediction_home_win(self, client: TestClient, sample_match: Match):
        """A higher home energy predicts the home team."""
        url = f"/api/v1/predictions/matches/{sample_match.id}/predict"
        query = {"home_energy": 70.0, "away_energy": 30.0}

        resp = client.post(url, params=query)

        assert resp.status_code == 201
        assert resp.json()["predicted_winner"] == "PSG"

    def test_create_prediction_away_win(self, client: TestClient, sample_match: Match):
        """A higher away energy predicts the away team."""
        url = f"/api/v1/predictions/matches/{sample_match.id}/predict"
        query = {"home_energy": 30.0, "away_energy": 70.0}

        resp = client.post(url, params=query)

        assert resp.status_code == 201
        assert resp.json()["predicted_winner"] == "Olympique de Marseille"

    def test_create_prediction_draw(self, client: TestClient, sample_match: Match):
        """Equal energies predict a draw with zero confidence."""
        url = f"/api/v1/predictions/matches/{sample_match.id}/predict"
        query = {"home_energy": 50.0, "away_energy": 50.0}

        resp = client.post(url, params=query)

        assert resp.status_code == 201
        body = resp.json()
        assert body["predicted_winner"] == "Draw"
        assert body["confidence"] == "0.0%"

    def test_create_prediction_nonexistent_match(self, client: TestClient):
        """An unknown match id is rejected with a 400 "not found" error."""
        query = {"home_energy": 65.0, "away_energy": 45.0}

        resp = client.post("/api/v1/predictions/matches/999/predict", params=query)

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "not found" in detail.lower()

    def test_create_prediction_negative_energy(self, client: TestClient, sample_match: Match):
        """A negative energy value is rejected with a 400."""
        url = f"/api/v1/predictions/matches/{sample_match.id}/predict"
        query = {"home_energy": -10.0, "away_energy": 45.0}

        resp = client.post(url, params=query)

        assert resp.status_code == 400
|
|
|
|
|
|
class TestGetPredictionEndpoint:
    """Exercise GET /api/v1/predictions/{prediction_id}."""

    def test_get_prediction_success(self, client: TestClient, sample_prediction: Prediction):
        """An existing prediction is returned with the fields that were stored."""
        resp = client.get(f"/api/v1/predictions/{sample_prediction.id}")

        assert resp.status_code == 200
        body = resp.json()

        assert body["id"] == sample_prediction.id
        assert body["match_id"] == sample_prediction.match_id
        assert body["energy_score"] == sample_prediction.energy_score
        assert body["confidence"] == sample_prediction.confidence
        assert body["predicted_winner"] == sample_prediction.predicted_winner

    def test_get_prediction_not_found(self, client: TestClient):
        """An unknown prediction id yields a 404 "not found" error."""
        resp = client.get("/api/v1/predictions/999")

        assert resp.status_code == 404
        detail = resp.json()["detail"]
        assert "not found" in detail.lower()
|
|
|
|
|
|
class TestGetPredictionsForMatchEndpoint:
    """Exercise GET /api/v1/predictions/matches/{match_id}."""

    def test_get_predictions_for_match_success(self, client: TestClient, db_session: Session, sample_match: Match):
        """Every prediction created for the match is listed and counted."""
        from app.services.prediction_service import PredictionService

        predictions = PredictionService(db_session)

        # Seed three predictions with increasing energies.
        for step in range(3):
            predictions.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + step * 10,
                away_energy=40.0 + step * 5,
            )

        resp = client.get(f"/api/v1/predictions/matches/{sample_match.id}")

        assert resp.status_code == 200
        body = resp.json()

        assert "data" in body
        assert "count" in body
        assert body["count"] == 3
        assert len(body["data"]) == 3
        assert all(item["match_id"] == sample_match.id for item in body["data"])

    def test_get_predictions_for_empty_match(self, client: TestClient, sample_match: Match):
        """A match without predictions yields an empty list and a zero count."""
        resp = client.get(f"/api/v1/predictions/matches/{sample_match.id}")

        assert resp.status_code == 200
        body = resp.json()

        assert body["count"] == 0
        assert len(body["data"]) == 0
|
|
|
|
|
|
class TestGetLatestPredictionForMatchEndpoint:
    """Exercise GET /api/v1/predictions/matches/{match_id}/latest."""

    def test_get_latest_prediction_success(self, client: TestClient, db_session: Session, sample_match: Match):
        """The most recently created prediction for the match is returned."""
        import time

        from app.services.prediction_service import PredictionService

        predictions = PredictionService(db_session)

        # Older prediction; only its existence matters to this test.
        predictions.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=50.0,
            away_energy=40.0,
        )

        # Short pause between the two creations so they are not simultaneous.
        time.sleep(0.01)

        newest = predictions.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=60.0,
            away_energy=35.0,
        )

        resp = client.get(f"/api/v1/predictions/matches/{sample_match.id}/latest")

        assert resp.status_code == 200
        body = resp.json()

        assert body["id"] == newest.id
        assert body["match_id"] == sample_match.id

    def test_get_latest_prediction_not_found(self, client: TestClient, sample_match: Match):
        """A match without predictions yields a 404 with an explanatory detail."""
        resp = client.get(f"/api/v1/predictions/matches/{sample_match.id}/latest")

        assert resp.status_code == 404
        detail = resp.json()["detail"]
        assert "No predictions found" in detail
|
|
|
|
|
|
class TestDeletePredictionEndpoint:
    """Exercise DELETE /api/v1/predictions/{prediction_id}."""

    def test_delete_prediction_success(self, client: TestClient, sample_prediction: Prediction):
        """Deleting an existing prediction returns 204 and a later GET returns 404."""
        delete_resp = client.delete(f"/api/v1/predictions/{sample_prediction.id}")

        assert delete_resp.status_code == 204

        # The prediction must no longer be retrievable.
        lookup_resp = client.get(f"/api/v1/predictions/{sample_prediction.id}")
        assert lookup_resp.status_code == 404

    def test_delete_prediction_not_found(self, client: TestClient):
        """Deleting an unknown prediction id yields a 404."""
        delete_resp = client.delete("/api/v1/predictions/999")

        assert delete_resp.status_code == 404
|
|
|
|
|
|
class TestGetPredictionsEndpoint:
    """Test GET /api/v1/predictions endpoint with pagination and filters."""

    def test_get_predictions_default(self, client: TestClient, db_session: Session, sample_match: Match):
        """Test getting predictions with default parameters.

        Checks the response envelope (``data`` plus ``meta``) and the default
        pagination values reported in ``meta``.
        """
        from app.services.prediction_service import PredictionService

        service = PredictionService(db_session)

        # Create multiple predictions
        for i in range(5):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + i * 5,
                away_energy=40.0 + i * 3,
            )

        response = client.get("/api/v1/predictions")

        assert response.status_code == 200
        data = response.json()

        assert "data" in data
        assert "meta" in data
        assert "total" in data["meta"]
        assert "limit" in data["meta"]
        assert "offset" in data["meta"]
        assert "timestamp" in data["meta"]
        assert data["meta"]["version"] == "v1"
        assert data["meta"]["limit"] == 20
        assert data["meta"]["offset"] == 0

    def test_get_predictions_with_pagination(self, client: TestClient, db_session: Session, sample_match: Match):
        """Test getting predictions with custom pagination.

        With 25 predictions stored, two consecutive pages of 10 must each be
        full and ``meta`` must echo the requested limit/offset and the total.
        """
        from app.services.prediction_service import PredictionService

        service = PredictionService(db_session)

        # Create 25 predictions
        for i in range(25):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + i,
                away_energy=40.0 + i,
            )

        # Get first page with limit 10
        response = client.get("/api/v1/predictions?limit=10&offset=0")

        assert response.status_code == 200
        data = response.json()

        assert len(data["data"]) == 10
        assert data["meta"]["limit"] == 10
        assert data["meta"]["offset"] == 0
        assert data["meta"]["total"] == 25

        # Get second page
        response2 = client.get("/api/v1/predictions?limit=10&offset=10")

        assert response2.status_code == 200
        data2 = response2.json()

        assert len(data2["data"]) == 10
        assert data2["meta"]["offset"] == 10

    def test_get_predictions_with_league_filter(self, client: TestClient, db_session: Session):
        """Test filtering predictions by league.

        One prediction is stored per league; filtering on "Ligue 1" must
        return only the prediction whose match is in that league.
        """
        from app.services.prediction_service import PredictionService

        service = PredictionService(db_session)

        # Create matches in different leagues
        match1 = Match(
            home_team="PSG",
            away_team="Olympique de Marseille",
            date=datetime.now(timezone.utc),
            league="Ligue 1",
            status="scheduled",
        )
        match2 = Match(
            home_team="Real Madrid",
            away_team="Barcelona",
            date=datetime.now(timezone.utc),
            league="La Liga",
            status="scheduled",
        )
        db_session.add_all([match1, match2])
        db_session.commit()
        db_session.refresh(match1)
        db_session.refresh(match2)

        # Create predictions for both matches
        service.create_prediction_for_match(match1.id, 65.0, 45.0)
        service.create_prediction_for_match(match2.id, 60.0, 50.0)

        # Filter by Ligue 1. The value contains a space, so it is passed via
        # ``params`` and URL-encoded by the client instead of being embedded
        # raw in the query string.
        response = client.get("/api/v1/predictions", params={"league": "Ligue 1"})

        assert response.status_code == 200
        data = response.json()

        assert len(data["data"]) == 1
        assert data["data"][0]["match"]["league"] == "Ligue 1"

    def test_get_predictions_with_date_filter(self, client: TestClient, db_session: Session, sample_match: Match):
        """Test filtering predictions by date range.

        Two matches are stored on either side of a fixed ``date_min`` cutoff,
        each with one prediction; only the prediction whose match is dated
        after the cutoff may be returned.
        """
        from app.services.prediction_service import PredictionService

        service = PredictionService(db_session)

        # Fixed dates keep the test deterministic. The earlier version derived
        # them from "now" and dated the past match 2025-01-01 at the current
        # time of day, which is not before the 2025-01-01T00:00:00Z cutoff,
        # so the filter was never exercised.
        date_min = "2025-01-01T00:00:00Z"
        future_date = datetime(2025, 6, 1, tzinfo=timezone.utc)
        past_date = datetime(2024, 6, 1, tzinfo=timezone.utc)

        future_match = Match(
            home_team="PSG",
            away_team="Olympique de Marseille",
            date=future_date,
            league="Ligue 1",
            status="scheduled",
        )
        past_match = Match(
            home_team="Real Madrid",
            away_team="Barcelona",
            date=past_date,
            league="La Liga",
            status="scheduled",
        )
        db_session.add_all([future_match, past_match])
        db_session.commit()
        db_session.refresh(future_match)
        db_session.refresh(past_match)

        service.create_prediction_for_match(future_match.id, 65.0, 45.0)
        service.create_prediction_for_match(past_match.id, 60.0, 50.0)

        # Filter by date range
        response = client.get("/api/v1/predictions", params={"date_min": date_min})

        assert response.status_code == 200
        data = response.json()

        # Should include future match but not past match. ``sample_match`` has
        # no predictions, so the league identifies which match came back.
        returned_leagues = [pred["match"]["league"] for pred in data["data"]]
        assert returned_leagues == ["Ligue 1"]
        assert all(pred["match"]["date"] >= "2025-01-01" for pred in data["data"])

    def test_get_predictions_invalid_limit(self, client: TestClient):
        """Test getting predictions with invalid limit."""
        response = client.get("/api/v1/predictions?limit=150")

        # Pydantic should validate this before endpoint executes
        assert response.status_code == 422

    def test_get_predictions_empty(self, client: TestClient):
        """Test getting predictions when none exist."""
        response = client.get("/api/v1/predictions")

        assert response.status_code == 200
        data = response.json()

        assert len(data["data"]) == 0
        assert data["meta"]["total"] == 0
|
|
|
|
|
|
class TestGetPredictionByMatchIdEndpoint:
    """Exercise GET /api/v1/predictions/match/{match_id}."""

    def test_get_prediction_by_match_id_success(self, client: TestClient, db_session: Session, sample_match: Match):
        """The detail view embeds the match, the prediction fields and a history."""
        from app.services.prediction_service import PredictionService

        predictions = PredictionService(db_session)

        # A single prediction is enough for the detail view.
        created = predictions.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=65.0,
            away_energy=45.0,
        )

        resp = client.get(f"/api/v1/predictions/match/{sample_match.id}")

        assert resp.status_code == 200
        body = resp.json()

        assert body["id"] == created.id
        assert body["match_id"] == sample_match.id
        assert "match" in body
        assert body["match"]["home_team"] == "PSG"
        assert body["match"]["away_team"] == "Olympique de Marseille"
        assert "energy_score" in body
        assert "confidence" in body
        assert "predicted_winner" in body
        assert "history" in body
        assert len(body["history"]) >= 1

    def test_get_prediction_by_match_id_with_history(self, client: TestClient, db_session: Session, sample_match: Match):
        """The latest prediction is returned and the history is listed newest first."""
        import time

        from app.services.prediction_service import PredictionService

        predictions = PredictionService(db_session)

        # Three predictions for the same match, with a short pause between
        # consecutive creations.
        energy_pairs = ((50.0, 40.0), (60.0, 35.0), (70.0, 30.0))
        created = []
        for home, away in energy_pairs:
            if created:
                time.sleep(0.01)
            created.append(
                predictions.create_prediction_for_match(
                    match_id=sample_match.id,
                    home_energy=home,
                    away_energy=away,
                )
            )
        first, second, third = created

        resp = client.get(f"/api/v1/predictions/match/{sample_match.id}")

        assert resp.status_code == 200
        body = resp.json()

        # The top-level prediction is the most recent one.
        assert body["id"] == third.id

        # The history holds all three, newest first.
        history = body["history"]
        assert len(history) == 3
        assert history[0]["id"] == third.id
        assert history[1]["id"] == second.id
        assert history[2]["id"] == first.id

    def test_get_prediction_by_match_id_not_found(self, client: TestClient):
        """An unknown match id yields a 404 with an explanatory detail."""
        resp = client.get("/api/v1/predictions/match/999")

        assert resp.status_code == 404
        detail = resp.json()["detail"]
        assert "No predictions found" in detail

    def test_get_prediction_by_match_id_no_predictions(self, client: TestClient, sample_match: Match):
        """An existing match without predictions yields a 404."""
        resp = client.get(f"/api/v1/predictions/match/{sample_match.id}")

        assert resp.status_code == 404