"""Unit tests for the Prediction model."""

from datetime import datetime, timezone

import pytest
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.models.match import Match
from app.models.prediction import Prediction
from app.database import Base, engine, SessionLocal


@pytest.fixture(scope="function")
def db_session():
    """Yield a fresh database session backed by a freshly created schema.

    Every table registered on ``Base.metadata`` is created before the test and
    dropped afterwards, so each test starts from an empty database.

    Yields:
        A ``SessionLocal`` session bound to the application ``engine``.
    """
    # NOTE(review): this uses the application's ``engine`` directly and calls
    # ``drop_all`` on it during teardown; confirm the test run points that
    # engine at a disposable database.
    Base.metadata.create_all(bind=engine)
    try:
        session = SessionLocal()
        try:
            yield session
        finally:
            # Discard anything left uncommitted (e.g. the failed transaction
            # after an expected IntegrityError) before releasing the session.
            session.rollback()
            session.close()
    finally:
        # Drop the schema even if session creation or teardown raised, so no
        # tables or rows leak into the next test.
        Base.metadata.drop_all(bind=engine)


class TestPredictionModel:
    """Test Prediction SQLAlchemy model.

    Each test receives its own ``db_session`` from the fixture above, which
    rolls back and drops the schema on teardown, so tests may commit freely.
    """

    @staticmethod
    def _persist(session: Session, instance):
        """Add *instance* to *session*, commit, refresh it, and return it.

        Refreshing after the commit loads database-generated values such as
        the primary key.
        """
        session.add(instance)
        session.commit()
        session.refresh(instance)
        return instance

    @classmethod
    def _create_match(
        cls, session: Session, home_team: str, away_team: str, league: str
    ) -> Match:
        """Persist and return a scheduled match dated now (UTC)."""
        match = Match(
            home_team=home_team,
            away_team=away_team,
            date=datetime.now(timezone.utc),
            league=league,
            status="scheduled",
        )
        return cls._persist(session, match)

    def test_prediction_creation(self, db_session: Session):
        """Test creating a prediction in database."""
        match = self._create_match(
            db_session, "PSG", "Olympique de Marseille", "Ligue 1"
        )

        prediction = self._persist(
            db_session,
            Prediction(
                match_id=match.id,
                energy_score="high",
                confidence="85%",
                predicted_winner="PSG",
                created_at=datetime.now(timezone.utc),
            ),
        )

        assert prediction.id is not None
        assert prediction.match_id == match.id
        assert prediction.energy_score == "high"
        assert prediction.confidence == "85%"
        assert prediction.predicted_winner == "PSG"

    def test_prediction_required_fields(self, db_session: Session):
        """Test that a prediction without ``predicted_winner`` fails to commit."""
        match = self._create_match(db_session, "Barcelona", "Real Madrid", "La Liga")

        # predicted_winner is deliberately omitted; the test expects the
        # column to be non-nullable so the INSERT is rejected by the database.
        db_session.add(
            Prediction(
                match_id=match.id,
                energy_score="high",
                confidence="90%",
                created_at=datetime.now(timezone.utc),
            )
        )

        with pytest.raises(IntegrityError):
            db_session.commit()

    def test_prediction_foreign_key_constraint(self, db_session: Session):
        """Test that match_id must reference an existing match."""
        # NOTE(review): SQLite only enforces foreign keys when
        # ``PRAGMA foreign_keys=ON`` is set on each connection; confirm the
        # test engine enables it, otherwise this commit will succeed.
        db_session.add(
            Prediction(
                match_id=999,  # No match exists in the freshly created schema.
                energy_score="high",
                confidence="90%",
                predicted_winner="PSG",
                created_at=datetime.now(timezone.utc),
            )
        )

        with pytest.raises(IntegrityError):
            db_session.commit()

    def test_prediction_to_dict(self, db_session: Session):
        """Test converting prediction to dictionary."""
        match = self._create_match(
            db_session, "Manchester City", "Liverpool", "Premier League"
        )

        prediction = self._persist(
            db_session,
            Prediction(
                match_id=match.id,
                energy_score="medium",
                confidence="70%",
                predicted_winner="Liverpool",
                created_at=datetime.now(timezone.utc),
            ),
        )

        prediction_dict = prediction.to_dict()

        assert prediction_dict['match_id'] == match.id
        assert prediction_dict['energy_score'] == "medium"
        assert prediction_dict['confidence'] == "70%"
        assert prediction_dict['predicted_winner'] == "Liverpool"
        assert 'id' in prediction_dict
        assert 'created_at' in prediction_dict

    def test_prediction_repr(self, db_session: Session):
        """Test prediction __repr__ method."""
        match = self._create_match(db_session, "Juventus", "Inter Milan", "Serie A")

        prediction = self._persist(
            db_session,
            Prediction(
                match_id=match.id,
                energy_score="low",
                confidence="50%",
                predicted_winner="Juventus",
                created_at=datetime.now(timezone.utc),
            ),
        )

        repr_str = repr(prediction)

        assert "Prediction" in repr_str
        assert "id=" in repr_str
        assert "match_id=" in repr_str
        assert "confidence=50%" in repr_str

    def test_prediction_match_relationship(self, db_session: Session):
        """Test prediction relationship with match."""
        match = self._create_match(
            db_session, "Bayern Munich", "Borussia Dortmund", "Bundesliga"
        )

        prediction = self._persist(
            db_session,
            Prediction(
                match_id=match.id,
                energy_score="very_high",
                confidence="95%",
                predicted_winner="Bayern Munich",
                created_at=datetime.now(timezone.utc),
            ),
        )

        # Navigate the Prediction -> Match relationship.
        assert prediction.match.id == match.id
        assert prediction.match.home_team == "Bayern Munich"
        assert prediction.match.away_team == "Borussia Dortmund"

    def test_multiple_predictions_per_match(self, db_session: Session):
        """Test that a match can have multiple predictions."""
        match = self._create_match(db_session, "Ajax", "Feyenoord", "Eredivisie")

        predictions = [
            Prediction(
                match_id=match.id,
                energy_score=energy_score,
                confidence=confidence,
                predicted_winner=winner,
                created_at=datetime.now(timezone.utc),
            )
            for energy_score, confidence, winner in (
                ("high", "80%", "Ajax"),
                ("medium", "60%", "Ajax"),
                ("low", "40%", "Feyenoord"),
            )
        ]
        db_session.add_all(predictions)
        db_session.commit()

        # Reload the match so its predictions collection reflects the new rows.
        db_session.refresh(match)
        assert len(match.predictions) == 3
        assert sorted(p.confidence for p in match.predictions) == ["40%", "60%", "80%"]