259 lines
8.9 KiB
Python
259 lines
8.9 KiB
Python
"""Unit tests for the prediction service (``PredictionService``)."""
|
|
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.database import Base, engine, SessionLocal
|
|
from app.models.match import Match
|
|
from app.models.prediction import Prediction
|
|
from app.services.prediction_service import PredictionService
|
|
|
|
|
|
@pytest.fixture(scope="function")
def db_session():
    """Yield a SQLAlchemy session backed by freshly created tables.

    All tables registered on ``Base.metadata`` are created on ``engine`` before
    the test and dropped afterwards, so every test starts from an empty schema.

    NOTE(review): this fixture uses the application's ``engine`` from
    ``app.database`` directly, and the teardown drops every table known to
    ``Base.metadata`` on it. Confirm that ``engine`` points at a disposable
    test database whenever this suite runs.
    """
    Base.metadata.create_all(bind=engine)
    try:
        session = SessionLocal()
        try:
            yield session
            # Discard anything the test left uncommitted before closing.
            session.rollback()
        finally:
            session.close()
    finally:
        # Always drop the tables once they were created. If the drop were
        # skipped (e.g. because rollback raised), the next test's
        # ``create_all`` would reuse the existing tables together with their
        # stale rows.
        Base.metadata.drop_all(bind=engine)
|
|
|
|
|
|
@pytest.fixture
def sample_match(db_session: Session) -> Match:
    """Persist and return a scheduled Ligue 1 match: PSG vs Olympique de Marseille.

    The match is committed and then refreshed, so attributes populated by the
    database (e.g. the ``id`` the tests read) are available on the returned
    instance. The kickoff time is "now" in UTC.
    """
    kickoff = datetime.now(timezone.utc)
    fixture_match = Match(
        home_team="PSG",
        away_team="Olympique de Marseille",
        date=kickoff,
        league="Ligue 1",
        status="scheduled",
    )

    db_session.add(fixture_match)
    db_session.commit()
    db_session.refresh(fixture_match)

    return fixture_match
|
|
|
|
|
|
class TestPredictionServiceCreate:
    """Tests covering how ``PredictionService`` builds new predictions."""

    def test_create_prediction_home_win(self, db_session: Session, sample_match: Match):
        """A stronger home side is predicted to win and keeps the supplied label."""
        svc = PredictionService(db_session)

        result = svc.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=65.0,
            away_energy=45.0,
            energy_score_label="high",
        )

        assert result.id is not None
        assert result.match_id == sample_match.id
        assert result.predicted_winner == "PSG"
        assert result.energy_score == "high"
        assert "%" in result.confidence

    def test_create_prediction_away_win(self, db_session: Session, sample_match: Match):
        """A stronger away side is predicted to win."""
        svc = PredictionService(db_session)

        result = svc.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=35.0,
            away_energy=70.0,
        )

        assert result.predicted_winner == "Olympique de Marseille"
        assert result.id is not None

    def test_create_prediction_draw(self, db_session: Session, sample_match: Match):
        """Equal energies yield a draw reported with a confidence of "0.0%"."""
        svc = PredictionService(db_session)

        result = svc.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=50.0,
            away_energy=50.0,
        )

        assert result.predicted_winner == "Draw"
        assert result.confidence == "0.0%"

    def test_create_prediction_with_default_energy_label(self, db_session: Session, sample_match: Match):
        """Without an explicit label, the energy score is derived from the energies."""
        svc = PredictionService(db_session)

        # Both sides highly energetic -> a "high"-family label.
        energetic = svc.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=80.0,
            away_energy=70.0,
        )
        assert energetic.energy_score in ["high", "very_high"]

        # Both sides subdued -> "low".
        subdued = svc.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=20.0,
            away_energy=10.0,
        )
        assert subdued.energy_score == "low"

    def test_create_prediction_nonexistent_match(self, db_session: Session):
        """An unknown match id is rejected with a ValueError naming that id."""
        svc = PredictionService(db_session)
        missing_id = 999

        with pytest.raises(ValueError, match="Match with id 999 not found"):
            svc.create_prediction_for_match(
                match_id=missing_id,
                home_energy=65.0,
                away_energy=45.0,
            )

    def test_create_prediction_negative_energy(self, db_session: Session, sample_match: Match):
        """A negative energy value is rejected with a ValueError."""
        svc = PredictionService(db_session)

        with pytest.raises(ValueError, match="cannot be negative"):
            svc.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=-10.0,
                away_energy=45.0,
            )

    def test_create_prediction_invalid_energy_type(self, db_session: Session, sample_match: Match):
        """A non-numeric energy value is rejected with a ValueError."""
        svc = PredictionService(db_session)

        with pytest.raises(ValueError, match="must be numeric"):
            svc.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy="invalid",
                away_energy=45.0,
            )
|
|
|
|
|
|
class TestPredictionServiceRetrieve:
    """Test prediction service retrieval methods."""

    def test_get_prediction_by_id(self, db_session: Session, sample_match: Match):
        """A created prediction can be fetched back by its id."""
        service = PredictionService(db_session)

        created = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=65.0,
            away_energy=45.0,
        )

        retrieved = service.get_prediction_by_id(created.id)

        assert retrieved is not None
        assert retrieved.id == created.id
        assert retrieved.match_id == created.match_id

    def test_get_prediction_by_id_not_found(self, db_session: Session):
        """Looking up an id that was never created returns None."""
        service = PredictionService(db_session)

        # The schema is recreated empty for every test, so id 999 cannot exist.
        retrieved = service.get_prediction_by_id(999)

        assert retrieved is None

    def test_get_predictions_for_match(self, db_session: Session, sample_match: Match):
        """Every prediction created for a match is returned for that match."""
        service = PredictionService(db_session)

        # Create three predictions with distinct energy values.
        for step in range(3):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + step * 10,
                away_energy=40.0 + step * 5,
            )

        predictions = service.get_predictions_for_match(sample_match.id)

        assert len(predictions) == 3
        assert all(p.match_id == sample_match.id for p in predictions)

    def test_get_predictions_for_empty_match(self, db_session: Session, sample_match: Match):
        """A match without predictions yields an empty collection."""
        service = PredictionService(db_session)

        predictions = service.get_predictions_for_match(sample_match.id)

        assert len(predictions) == 0

    def test_get_latest_prediction_for_match(self, db_session: Session, sample_match: Match):
        """The most recently created prediction is reported as the latest."""
        import time

        service = PredictionService(db_session)

        prediction1 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=50.0,
            away_energy=40.0,
        )

        # Small delay so the two predictions are created at different times.
        # NOTE(review): if "latest" is resolved from a timestamp column with
        # coarse resolution (e.g. whole seconds), 10 ms may not separate the
        # two rows and this test could be flaky. Confirm how
        # get_latest_prediction_for_match orders rows.
        time.sleep(0.01)

        prediction2 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=60.0,
            away_energy=35.0,
        )

        latest = service.get_latest_prediction_for_match(sample_match.id)

        # Fail with a clear assertion rather than an AttributeError on None.
        assert latest is not None
        assert latest.id == prediction2.id
        assert latest.id != prediction1.id

    def test_get_latest_prediction_for_empty_match(self, db_session: Session, sample_match: Match):
        """A match without predictions has no latest prediction (None)."""
        service = PredictionService(db_session)

        latest = service.get_latest_prediction_for_match(sample_match.id)

        assert latest is None
|
|
|
|
|
|
class TestPredictionServiceDelete:
    """Tests for ``PredictionService.delete_prediction``."""

    def test_delete_prediction(self, db_session: Session, sample_match: Match):
        """Deleting an existing prediction reports True and removes it."""
        svc = PredictionService(db_session)

        doomed = svc.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=65.0,
            away_energy=45.0,
        )

        outcome = svc.delete_prediction(doomed.id)
        assert outcome is True

        # The row must no longer be retrievable by its old id.
        assert svc.get_prediction_by_id(doomed.id) is None

    def test_delete_prediction_not_found(self, db_session: Session):
        """Deleting an id that does not exist reports False instead of raising."""
        svc = PredictionService(db_session)

        outcome = svc.delete_prediction(999)

        assert outcome is False
|