"""
Unit tests for Prediction API endpoints.

Every test runs against a schema that is created before the test and dropped
after it (see the ``db_session`` fixture), so tests do not share rows.
"""

import time
from datetime import datetime, timezone

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from app.database import Base, engine, SessionLocal
from app.main import app
from app.models.match import Match
from app.models.prediction import Prediction
from app.services.prediction_service import PredictionService


@pytest.fixture(scope="function")
def db_session():
    """Create a fresh database session for each test.

    All tables registered on ``Base`` are created before the test and dropped
    afterwards.

    Yields:
        An open ``Session`` bound to the application's ``SessionLocal``.
    """
    Base.metadata.create_all(bind=engine)
    session = SessionLocal()
    try:
        yield session
    finally:
        # Discard anything left uncommitted before tearing the schema down.
        session.rollback()
        session.close()
        Base.metadata.drop_all(bind=engine)


@pytest.fixture
def client(db_session: Session):
    """Create a test client.

    ``db_session`` is requested only for its setup/teardown: it guarantees the
    tables exist for tests that ask for ``client`` alone (the "not found" and
    "empty" cases), which otherwise would run without ``create_all`` having
    been called by this module.
    """
    return TestClient(app)


@pytest.fixture
def sample_match(db_session: Session) -> Match:
    """Create and persist a sample scheduled Ligue 1 match (PSG vs OM)."""
    match = Match(
        home_team="PSG",
        away_team="Olympique de Marseille",
        date=datetime.now(timezone.utc),
        league="Ligue 1",
        status="scheduled",
    )
    db_session.add(match)
    db_session.commit()
    db_session.refresh(match)
    return match


@pytest.fixture
def sample_prediction(db_session: Session, sample_match: Match) -> Prediction:
    """Create and persist a sample prediction attached to ``sample_match``."""
    prediction = Prediction(
        match_id=sample_match.id,
        energy_score="high",
        confidence="75.0%",
        predicted_winner="PSG",
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(prediction)
    db_session.commit()
    db_session.refresh(prediction)
    return prediction


class TestCreatePredictionEndpoint:
    """Test POST /api/v1/predictions/matches/{match_id}/predict endpoint."""

    def test_create_prediction_success(self, client: TestClient, sample_match: Match):
        """Test creating a prediction successfully."""
        response = client.post(
            f"/api/v1/predictions/matches/{sample_match.id}/predict",
            params={
                "home_energy": 65.0,
                "away_energy": 45.0,
                "energy_score_label": "high",
            },
        )

        assert response.status_code == 201
        data = response.json()
        assert data["match_id"] == sample_match.id
        assert data["energy_score"] == "high"
        assert "%" in data["confidence"]
        assert data["predicted_winner"] == "PSG"
        assert "id" in data
        assert "created_at" in data

    def test_create_prediction_default_energy_label(self, client: TestClient, sample_match: Match):
        """Test creating prediction with auto-generated energy label."""
        response = client.post(
            f"/api/v1/predictions/matches/{sample_match.id}/predict",
            params={"home_energy": 75.0, "away_energy": 65.0},
        )

        assert response.status_code == 201
        data = response.json()
        assert data["energy_score"] in {"high", "very_high", "medium", "low"}

    def test_create_prediction_home_win(self, client: TestClient, sample_match: Match):
        """Test prediction when home team wins."""
        response = client.post(
            f"/api/v1/predictions/matches/{sample_match.id}/predict",
            params={"home_energy": 70.0, "away_energy": 30.0},
        )

        assert response.status_code == 201
        assert response.json()["predicted_winner"] == "PSG"

    def test_create_prediction_away_win(self, client: TestClient, sample_match: Match):
        """Test prediction when away team wins."""
        response = client.post(
            f"/api/v1/predictions/matches/{sample_match.id}/predict",
            params={"home_energy": 30.0, "away_energy": 70.0},
        )

        assert response.status_code == 201
        assert response.json()["predicted_winner"] == "Olympique de Marseille"

    def test_create_prediction_draw(self, client: TestClient, sample_match: Match):
        """Test prediction for draw (equal energies)."""
        response = client.post(
            f"/api/v1/predictions/matches/{sample_match.id}/predict",
            params={"home_energy": 50.0, "away_energy": 50.0},
        )

        assert response.status_code == 201
        data = response.json()
        assert data["predicted_winner"] == "Draw"
        assert data["confidence"] == "0.0%"

    def test_create_prediction_nonexistent_match(self, client: TestClient):
        """Test creating prediction for non-existent match.

        This endpoint is expected to answer 400 (not 404) with a "not found"
        detail when the match id is unknown.
        """
        response = client.post(
            "/api/v1/predictions/matches/999/predict",
            params={"home_energy": 65.0, "away_energy": 45.0},
        )

        assert response.status_code == 400
        assert "not found" in response.json()["detail"].lower()

    def test_create_prediction_negative_energy(self, client: TestClient, sample_match: Match):
        """Test creating prediction with negative energy."""
        response = client.post(
            f"/api/v1/predictions/matches/{sample_match.id}/predict",
            params={"home_energy": -10.0, "away_energy": 45.0},
        )

        assert response.status_code == 400


class TestGetPredictionEndpoint:
    """Test GET /api/v1/predictions/{prediction_id} endpoint."""

    def test_get_prediction_success(self, client: TestClient, sample_prediction: Prediction):
        """Test getting a prediction by ID successfully."""
        response = client.get(f"/api/v1/predictions/{sample_prediction.id}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == sample_prediction.id
        assert data["match_id"] == sample_prediction.match_id
        assert data["energy_score"] == sample_prediction.energy_score
        assert data["confidence"] == sample_prediction.confidence
        assert data["predicted_winner"] == sample_prediction.predicted_winner

    def test_get_prediction_not_found(self, client: TestClient):
        """Test getting non-existent prediction."""
        response = client.get("/api/v1/predictions/999")

        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()


class TestGetPredictionsForMatchEndpoint:
    """Test GET /api/v1/predictions/matches/{match_id} endpoint."""

    def test_get_predictions_for_match_success(self, client: TestClient, db_session: Session, sample_match: Match):
        """Test getting all predictions for a match."""
        # Create multiple predictions
        service = PredictionService(db_session)
        for i in range(3):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + i * 10,
                away_energy=40.0 + i * 5,
            )

        response = client.get(f"/api/v1/predictions/matches/{sample_match.id}")

        assert response.status_code == 200
        data = response.json()
        assert "data" in data
        assert "count" in data
        assert data["count"] == 3
        assert len(data["data"]) == 3
        assert all(p["match_id"] == sample_match.id for p in data["data"])

    def test_get_predictions_for_empty_match(self, client: TestClient, sample_match: Match):
        """Test getting predictions for match with no predictions."""
        response = client.get(f"/api/v1/predictions/matches/{sample_match.id}")

        assert response.status_code == 200
        data = response.json()
        assert data["count"] == 0
        assert len(data["data"]) == 0


class TestGetLatestPredictionForMatchEndpoint:
    """Test GET /api/v1/predictions/matches/{match_id}/latest endpoint."""

    def test_get_latest_prediction_success(self, client: TestClient, db_session: Session, sample_match: Match):
        """Test getting latest prediction for a match."""
        service = PredictionService(db_session)

        # Create an older prediction; only the newer one should come back.
        service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=50.0,
            away_energy=40.0,
        )

        # Short pause so the two predictions are not created at the same instant.
        time.sleep(0.01)

        prediction2 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=60.0,
            away_energy=35.0,
        )

        response = client.get(f"/api/v1/predictions/matches/{sample_match.id}/latest")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == prediction2.id
        assert data["match_id"] == sample_match.id

    def test_get_latest_prediction_not_found(self, client: TestClient, sample_match: Match):
        """Test getting latest prediction for match with no predictions."""
        response = client.get(f"/api/v1/predictions/matches/{sample_match.id}/latest")

        assert response.status_code == 404
        assert "No predictions found" in response.json()["detail"]


class TestDeletePredictionEndpoint:
    """Test DELETE /api/v1/predictions/{prediction_id} endpoint."""

    def test_delete_prediction_success(self, client: TestClient, sample_prediction: Prediction):
        """Test deleting a prediction successfully."""
        response = client.delete(f"/api/v1/predictions/{sample_prediction.id}")

        assert response.status_code == 204

        # Verify prediction is deleted
        get_response = client.get(f"/api/v1/predictions/{sample_prediction.id}")
        assert get_response.status_code == 404

    def test_delete_prediction_not_found(self, client: TestClient):
        """Test deleting non-existent prediction."""
        response = client.delete("/api/v1/predictions/999")

        assert response.status_code == 404


class TestGetPredictionsEndpoint:
    """Test GET /api/v1/predictions endpoint with pagination and filters."""

    def test_get_predictions_default(self, client: TestClient, db_session: Session, sample_match: Match):
        """Test getting predictions with default parameters."""
        service = PredictionService(db_session)

        # Create multiple predictions
        for i in range(5):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + i * 5,
                away_energy=40.0 + i * 3,
            )

        response = client.get("/api/v1/predictions")

        assert response.status_code == 200
        data = response.json()
        assert "data" in data
        assert "meta" in data
        assert "total" in data["meta"]
        assert "limit" in data["meta"]
        assert "offset" in data["meta"]
        assert "timestamp" in data["meta"]
        assert data["meta"]["version"] == "v1"
        assert data["meta"]["limit"] == 20
        assert data["meta"]["offset"] == 0
        # The five predictions created above fit in the default page of 20.
        assert data["meta"]["total"] == 5
        assert len(data["data"]) == 5

    def test_get_predictions_with_pagination(self, client: TestClient, db_session: Session, sample_match: Match):
        """Test getting predictions with custom pagination."""
        service = PredictionService(db_session)

        # Create 25 predictions
        for i in range(25):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + i,
                away_energy=40.0 + i,
            )

        # Get first page with limit 10
        response = client.get("/api/v1/predictions?limit=10&offset=0")

        assert response.status_code == 200
        data = response.json()
        assert len(data["data"]) == 10
        assert data["meta"]["limit"] == 10
        assert data["meta"]["offset"] == 0
        assert data["meta"]["total"] == 25

        # Get second page
        response2 = client.get("/api/v1/predictions?limit=10&offset=10")

        assert response2.status_code == 200
        data2 = response2.json()
        assert len(data2["data"]) == 10
        assert data2["meta"]["offset"] == 10

        # Last page is partial: 25 rows minus the 20 already skipped.
        response3 = client.get("/api/v1/predictions?limit=10&offset=20")

        assert response3.status_code == 200
        data3 = response3.json()
        assert len(data3["data"]) == 5
        assert data3["meta"]["offset"] == 20
        assert data3["meta"]["total"] == 25

    def test_get_predictions_with_league_filter(self, client: TestClient, db_session: Session):
        """Test filtering predictions by league."""
        service = PredictionService(db_session)

        # Create matches in different leagues
        match1 = Match(
            home_team="PSG",
            away_team="Olympique de Marseille",
            date=datetime.now(timezone.utc),
            league="Ligue 1",
            status="scheduled",
        )
        match2 = Match(
            home_team="Real Madrid",
            away_team="Barcelona",
            date=datetime.now(timezone.utc),
            league="La Liga",
            status="scheduled",
        )
        db_session.add_all([match1, match2])
        db_session.commit()
        db_session.refresh(match1)
        db_session.refresh(match2)

        # Create predictions for both matches
        service.create_prediction_for_match(match1.id, 65.0, 45.0)
        service.create_prediction_for_match(match2.id, 60.0, 50.0)

        # Filter by Ligue 1; `params` takes care of encoding the space.
        response = client.get("/api/v1/predictions", params={"league": "Ligue 1"})

        assert response.status_code == 200
        data = response.json()
        assert len(data["data"]) == 1
        assert data["data"][0]["match"]["league"] == "Ligue 1"

    def test_get_predictions_with_date_filter(self, client: TestClient, db_session: Session, sample_match: Match):
        """Test filtering predictions by date range.

        Uses fixed match dates on either side of ``date_min`` so the outcome
        does not depend on the day the test runs, and so the earlier match is
        really outside the requested range.
        """
        service = PredictionService(db_session)

        # One match after the lower bound, one strictly before it.
        in_range_date = datetime(2025, 2, 20, 12, 0, tzinfo=timezone.utc)
        out_of_range_date = datetime(2024, 12, 31, 12, 0, tzinfo=timezone.utc)

        future_match = Match(
            home_team="PSG",
            away_team="Olympique de Marseille",
            date=in_range_date,
            league="Ligue 1",
            status="scheduled",
        )
        past_match = Match(
            home_team="Real Madrid",
            away_team="Barcelona",
            date=out_of_range_date,
            league="La Liga",
            status="scheduled",
        )
        db_session.add_all([future_match, past_match])
        db_session.commit()
        db_session.refresh(future_match)
        db_session.refresh(past_match)

        service.create_prediction_for_match(future_match.id, 65.0, 45.0)
        service.create_prediction_for_match(past_match.id, 60.0, 50.0)

        # Filter by date range
        response = client.get(
            "/api/v1/predictions",
            params={"date_min": "2025-01-01T00:00:00Z"},
        )

        assert response.status_code == 200
        data = response.json()
        # Should include the in-range match but not the earlier one.
        # NOTE(review): assumes date_min filters on the match date, as the
        # assertion on pred["match"]["date"] implies - confirm against the API.
        assert len(data["data"]) == 1
        assert all(pred["match"]["date"] >= "2025-01-01" for pred in data["data"])

    def test_get_predictions_invalid_limit(self, client: TestClient):
        """Test getting predictions with invalid limit."""
        response = client.get("/api/v1/predictions?limit=150")

        # Pydantic should validate this before endpoint executes
        assert response.status_code == 422

    def test_get_predictions_empty(self, client: TestClient):
        """Test getting predictions when none exist."""
        response = client.get("/api/v1/predictions")

        assert response.status_code == 200
        data = response.json()
        assert len(data["data"]) == 0
        assert data["meta"]["total"] == 0


class TestGetPredictionByMatchIdEndpoint:
    """Test GET /api/v1/predictions/match/{match_id} endpoint."""

    def test_get_prediction_by_match_id_success(self, client: TestClient, db_session: Session, sample_match: Match):
        """Test getting prediction details by match ID."""
        service = PredictionService(db_session)

        # Create a prediction
        prediction = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=65.0,
            away_energy=45.0,
        )

        response = client.get(f"/api/v1/predictions/match/{sample_match.id}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == prediction.id
        assert data["match_id"] == sample_match.id
        assert "match" in data
        assert data["match"]["home_team"] == "PSG"
        assert data["match"]["away_team"] == "Olympique de Marseille"
        assert "energy_score" in data
        assert "confidence" in data
        assert "predicted_winner" in data
        assert "history" in data
        assert len(data["history"]) >= 1

    def test_get_prediction_by_match_id_with_history(self, client: TestClient, db_session: Session, sample_match: Match):
        """Test getting prediction with history (newest first)."""
        service = PredictionService(db_session)

        # Create multiple predictions for the same match, pausing briefly
        # between them so they are not created at the same instant.
        prediction1 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=50.0,
            away_energy=40.0,
        )
        time.sleep(0.01)

        prediction2 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=60.0,
            away_energy=35.0,
        )
        time.sleep(0.01)

        prediction3 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=70.0,
            away_energy=30.0,
        )

        response = client.get(f"/api/v1/predictions/match/{sample_match.id}")

        assert response.status_code == 200
        data = response.json()

        # Should return latest prediction
        assert data["id"] == prediction3.id

        # History should contain all predictions
        assert len(data["history"]) == 3
        assert data["history"][0]["id"] == prediction3.id
        assert data["history"][1]["id"] == prediction2.id
        assert data["history"][2]["id"] == prediction1.id

    def test_get_prediction_by_match_id_not_found(self, client: TestClient):
        """Test getting prediction for non-existent match."""
        response = client.get("/api/v1/predictions/match/999")

        assert response.status_code == 404
        assert "No predictions found" in response.json()["detail"]

    def test_get_prediction_by_match_id_no_predictions(self, client: TestClient, sample_match: Match):
        """Test getting prediction for match without predictions."""
        response = client.get(f"/api/v1/predictions/match/{sample_match.id}")

        assert response.status_code == 404