"""
Prediction Service Module.

This module provides business logic for creating and managing match predictions.
"""

import math
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy.orm import Session

from app.ml.prediction_calculator import calculate_prediction, validate_prediction_result
from app.models.match import Match
from app.models.prediction import Prediction
from app.schemas.prediction import PredictionCreate, MatchInfo


class PredictionService:
    """Service for handling prediction business logic.

    Wraps a SQLAlchemy session and provides create, read, delete and
    paginated-listing operations for ``Prediction`` rows.
    """

    # Hard ceiling on the page size accepted by get_predictions_with_pagination.
    MAX_PAGE_SIZE = 100

    def __init__(self, db: Session):
        """
        Initialize the prediction service.

        Args:
            db: SQLAlchemy database session used for every query and commit.
        """
        self.db = db

    def create_prediction_for_match(
        self,
        match_id: int,
        home_energy: float,
        away_energy: float,
        energy_score_label: Optional[str] = None
    ) -> Prediction:
        """
        Create a prediction for a specific match based on energy scores.

        This method:
        1. Validates that the match exists
        2. Validates the energy scores
        3. Calculates the prediction using energy scores
        4. Stores the prediction in the database
        5. Returns the created prediction

        Args:
            match_id: ID of the match to predict
            home_energy: Energy score of the home team (finite, non-negative)
            away_energy: Energy score of the away team (finite, non-negative)
            energy_score_label: Optional label for the energy score. When
                omitted it is derived from the mean of the two scores:
                "very_high" (>= 70), "high" (>= 50), "medium" (>= 30),
                otherwise "low".

        Returns:
            The created Prediction object, refreshed from the database.

        Raises:
            ValueError: If the match doesn't exist, the energy scores are not
                finite non-negative numbers, or the calculator returns a
                result that fails validation.
        """
        match = self.db.query(Match).filter(Match.id == match_id).first()
        if match is None:
            raise ValueError(f"Match with id {match_id} not found")

        self._validate_energy_scores(home_energy, away_energy)

        prediction_result = calculate_prediction(home_energy, away_energy)
        if not validate_prediction_result(prediction_result):
            raise ValueError("Invalid prediction calculation result")

        if energy_score_label is None:
            energy_score_label = self._energy_label(home_energy, away_energy)

        # Map the calculator's side marker to a team name; anything other than
        # 'home' or 'away' is recorded as a draw.
        winner_side = prediction_result['predicted_winner']
        if winner_side == 'home':
            predicted_winner_name = match.home_team
        elif winner_side == 'away':
            predicted_winner_name = match.away_team
        else:
            predicted_winner_name = "Draw"

        prediction = Prediction(
            match_id=match_id,
            energy_score=energy_score_label,
            # Stored as display text with one decimal, e.g. "73.4%".
            confidence=f"{prediction_result['confidence']:.1f}%",
            predicted_winner=predicted_winner_name,
            created_at=datetime.now(timezone.utc)
        )

        self.db.add(prediction)
        self._commit()
        self.db.refresh(prediction)

        return prediction

    @staticmethod
    def _validate_energy_scores(home_energy: float, away_energy: float) -> None:
        """Raise ValueError unless both energy scores are finite, non-negative numbers."""
        scores = (home_energy, away_energy)
        # bool is a subclass of int, so True/False must be excluded explicitly.
        if any(isinstance(s, bool) or not isinstance(s, (int, float)) for s in scores):
            raise ValueError("Energy scores must be numeric values")
        if home_energy < 0 or away_energy < 0:
            raise ValueError("Energy scores cannot be negative")
        # NaN fails every comparison, so it slips past the negativity check; reject
        # it and +inf here. Python ints are always finite (and may overflow float()),
        # so only floats are checked.
        if any(isinstance(s, float) and not math.isfinite(s) for s in scores):
            raise ValueError("Energy scores must be finite numbers")

    @staticmethod
    def _energy_label(home_energy: float, away_energy: float) -> str:
        """Bucket the mean of the two energy scores into a coarse label."""
        avg_energy = (home_energy + away_energy) / 2
        if avg_energy >= 70:
            return "very_high"
        if avg_energy >= 50:
            return "high"
        if avg_energy >= 30:
            return "medium"
        return "low"

    def _commit(self) -> None:
        """Commit the session, rolling it back and re-raising if the commit fails.

        Without the rollback, a failed commit leaves the session in a failed
        transaction state that rejects further use until rolled back.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    def get_prediction_by_id(self, prediction_id: int) -> Optional[Prediction]:
        """
        Get a prediction by its ID.

        Args:
            prediction_id: ID of the prediction to retrieve

        Returns:
            Prediction object or None if not found
        """
        return self.db.query(Prediction).filter(Prediction.id == prediction_id).first()

    def get_predictions_for_match(self, match_id: int) -> list[Prediction]:
        """
        Get all predictions for a specific match.

        Args:
            match_id: ID of the match

        Returns:
            List of Prediction objects (no particular order is applied)
        """
        return self.db.query(Prediction).filter(Prediction.match_id == match_id).all()

    def get_latest_prediction_for_match(self, match_id: int) -> Optional[Prediction]:
        """
        Get the most recent prediction for a match.

        Args:
            match_id: ID of the match

        Returns:
            Latest Prediction object or None if no predictions exist
        """
        return (
            self.db.query(Prediction)
            .filter(Prediction.match_id == match_id)
            # id breaks ties between predictions sharing the same created_at.
            .order_by(Prediction.created_at.desc(), Prediction.id.desc())
            .first()
        )

    def delete_prediction(self, prediction_id: int) -> bool:
        """
        Delete a prediction by its ID.

        Args:
            prediction_id: ID of the prediction to delete

        Returns:
            True if deleted, False if not found
        """
        prediction = self.get_prediction_by_id(prediction_id)
        if prediction is None:
            return False
        self.db.delete(prediction)
        self._commit()
        return True

    def get_predictions_with_pagination(
        self,
        limit: int = 20,
        offset: int = 0,
        team_id: Optional[int] = None,
        league: Optional[str] = None,
        date_min: Optional[datetime] = None,
        date_max: Optional[datetime] = None
    ) -> tuple[list[Prediction], int]:
        """
        Get predictions with pagination and filters.

        This method retrieves predictions joined with match data, applies
        filters, and returns one page ordered by match date (ascending).

        Args:
            limit: Maximum number of predictions to return (clamped to 0..100)
            offset: Number of predictions to skip (negative values become 0)
            team_id: Optional team filter (see the review note in the body)
            league: Optional case-insensitive substring filter on league name
            date_min: Optional filter for matches on or after this date
            date_max: Optional filter for matches on or before this date

        Returns:
            Tuple of (list of predictions for the page, total matching count).
            If team_id is given but cannot be resolved, returns ([], 0).
        """
        query = self.db.query(Prediction).join(Match)

        if team_id:
            # NOTE(review): team_id is resolved by loading the *Match* whose primary
            # key equals team_id and reusing that match's team names; no Team table
            # is consulted. The filter then keeps matches sharing its home team OR
            # its away team. Confirm this is the intended semantics.
            team_match = self.db.query(Match).filter(Match.id == team_id).first()
            if team_match is None:
                # An unresolvable filter must match nothing; silently dropping it
                # would return every prediction.
                return [], 0
            query = query.filter(
                (Match.home_team == team_match.home_team) |
                (Match.away_team == team_match.away_team)
            )

        if league:
            query = query.filter(Match.league.ilike(f"%{league}%"))

        if date_min:
            query = query.filter(Match.date >= date_min)

        if date_max:
            query = query.filter(Match.date <= date_max)

        # Total matching rows, computed before LIMIT/OFFSET are applied.
        total = query.count()

        page_size = max(0, min(limit, self.MAX_PAGE_SIZE))
        start = max(0, offset)

        predictions = (
            query
            # Prediction.id makes the order total, so consecutive pages neither
            # repeat nor skip rows when several predictions share a match date.
            .order_by(Match.date.asc(), Prediction.id.asc())
            .limit(page_size)
            .offset(start)
            .all()
        )

        return predictions, total

    def get_prediction_with_details(self, match_id: int) -> Optional[dict]:
        """
        Get a prediction for a specific match with full details.

        Retrieves the latest prediction for a match and includes the match
        details plus every prediction recorded for that match as history.

        Args:
            match_id: ID of match

        Returns:
            Dictionary with prediction details or None if no prediction exists
        """
        prediction = self.get_latest_prediction_for_match(match_id)
        if prediction is None:
            return None

        match = prediction.match

        return {
            "id": prediction.id,
            "match_id": prediction.match_id,
            "match": {
                "id": match.id,
                "home_team": match.home_team,
                "away_team": match.away_team,
                "date": match.date.isoformat() if match.date else None,
                "league": match.league,
                "status": match.status,
                "actual_winner": match.actual_winner
            },
            "energy_score": prediction.energy_score,
            "confidence": prediction.confidence,
            "predicted_winner": prediction.predicted_winner,
            "created_at": prediction.created_at.isoformat() if prediction.created_at else None,
            "history": self._get_prediction_history(match_id)
        }

    def _get_prediction_history(self, match_id: int) -> list[dict]:
        """
        Get historical predictions for a match.

        Args:
            match_id: ID of match

        Returns:
            List of all predictions for the match as dicts, newest first
            (ties on created_at broken by id, matching
            get_latest_prediction_for_match).
        """
        predictions = (
            self.db.query(Prediction)
            .filter(Prediction.match_id == match_id)
            .order_by(Prediction.created_at.desc(), Prediction.id.desc())
            .all()
        )

        return [
            {
                "id": pred.id,
                "energy_score": pred.energy_score,
                "confidence": pred.confidence,
                "predicted_winner": pred.predicted_winner,
                "created_at": pred.created_at.isoformat() if pred.created_at else None
            }
            for pred in predictions
        ]