"""
Energy Score Service.

This module provides business logic for energy score calculation and storage.
"""

from datetime import datetime, timezone
from typing import List, Dict, Optional

from sqlalchemy.orm import Session
from sqlalchemy import and_

from app.ml.energy_calculator import (
    calculate_energy_score,
    adjust_weights_for_degraded_mode,
    get_source_weights
)
from app.models.energy_score import EnergyScore
from app.schemas.energy_score import (
    EnergyScoreCreate,
    EnergyScoreUpdate,
    EnergyScoreCalculationRequest,
    EnergyScoreCalculationResponse
)
from app.database import get_db


def _utcnow() -> datetime:
    """
    Return the current UTC time as a naive datetime.

    Produces the same value as the deprecated ``datetime.utcnow()`` (UTC wall
    clock, ``tzinfo=None``), so stored timestamps keep their existing form.

    Returns:
        Naive datetime expressed in UTC
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _commit_or_rollback(db: Session) -> None:
    """
    Commit the session, rolling back before re-raising if the commit fails.

    Args:
        db: Database session

    Raises:
        Exception: Whatever ``db.commit()`` raised, after the rollback
    """
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until it is
        # rolled back; reset it, then let the caller see the original error.
        db.rollback()
        raise


def calculate_and_store_energy_score(
    db: Session,
    request: EnergyScoreCalculationRequest
) -> EnergyScore:
    """
    Calculate energy score and store it in the database.

    Args:
        db: Database session
        request: Energy score calculation request

    Returns:
        Created EnergyScore object

    Raises:
        Exception: Any error raised while committing; the session is rolled
            back before the error propagates
    """
    # Calculate energy score using the ML module; missing sentiment lists
    # are passed as empty lists.
    result = calculate_energy_score(
        match_id=request.match_id,
        team_id=request.team_id,
        twitter_sentiments=request.twitter_sentiments or [],
        reddit_sentiments=request.reddit_sentiments or [],
        rss_sentiments=request.rss_sentiments or [],
        tweets_with_timestamps=request.tweets_with_timestamps or []
    )

    # Get adjusted weights for degraded mode tracking
    available_sources = result['sources_used']
    original_weights = get_source_weights()
    adjusted_weights = adjust_weights_for_degraded_mode(
        original_weights=original_weights,
        available_sources=available_sources
    )

    # One timestamp for both columns so a fresh record has
    # created_at == updated_at.
    now = _utcnow()

    # Create energy score record
    energy_score = EnergyScore(
        match_id=request.match_id,
        team_id=request.team_id,
        score=result['score'],
        confidence=result['confidence'],
        sources_used=available_sources,
        twitter_score=_calculate_component_score(request.twitter_sentiments),
        reddit_score=_calculate_component_score(request.reddit_sentiments),
        rss_score=_calculate_component_score(request.rss_sentiments),
        temporal_factor=result.get('temporal_factor'),
        twitter_weight=adjusted_weights.get('twitter'),
        reddit_weight=adjusted_weights.get('reddit'),
        rss_weight=adjusted_weights.get('rss'),
        created_at=now,
        updated_at=now
    )

    # Save to database
    db.add(energy_score)
    _commit_or_rollback(db)
    db.refresh(energy_score)

    return energy_score


def _calculate_component_score(sentiments: Optional[List[Dict]]) -> Optional[float]:
    """
    Calculate component score for a single source.

    Args:
        sentiments: List of sentiment dicts; an entry without a 'compound'
            key contributes 0 to the average

    Returns:
        Mean of the 'compound' values, or None if sentiments is None or empty
    """
    if not sentiments:
        return None

    # Simple average of compound scores (the list is non-empty here)
    total = sum(s.get('compound', 0) for s in sentiments)
    return total / len(sentiments)


def get_energy_score(
    db: Session,
    energy_score_id: int
) -> Optional[EnergyScore]:
    """
    Get an energy score by ID.

    Args:
        db: Database session
        energy_score_id: ID of the energy score

    Returns:
        EnergyScore object or None
    """
    return db.query(EnergyScore).filter(EnergyScore.id == energy_score_id).first()


def get_energy_scores_by_match(
    db: Session,
    match_id: int
) -> List[EnergyScore]:
    """
    Get all energy scores for a specific match.

    Args:
        db: Database session
        match_id: ID of the match

    Returns:
        List of EnergyScore objects
    """
    return db.query(EnergyScore).filter(EnergyScore.match_id == match_id).all()


def get_energy_scores_by_team(
    db: Session,
    team_id: int
) -> List[EnergyScore]:
    """
    Get all energy scores for a specific team.

    Args:
        db: Database session
        team_id: ID of the team

    Returns:
        List of EnergyScore objects
    """
    return db.query(EnergyScore).filter(EnergyScore.team_id == team_id).all()


def get_energy_score_by_match_and_team(
    db: Session,
    match_id: int,
    team_id: int
) -> Optional[EnergyScore]:
    """
    Get the most recent energy score for a specific match and team.

    Args:
        db: Database session
        match_id: ID of the match
        team_id: ID of the team

    Returns:
        EnergyScore object or None
    """
    return (
        db.query(EnergyScore)
        .filter(and_(EnergyScore.match_id == match_id,
                     EnergyScore.team_id == team_id))
        # id breaks ties between rows sharing a created_at value so the
        # "most recent" row is chosen deterministically.
        .order_by(EnergyScore.created_at.desc(), EnergyScore.id.desc())
        .first()
    )


def update_energy_score(
    db: Session,
    energy_score_id: int,
    update: EnergyScoreUpdate
) -> Optional[EnergyScore]:
    """
    Update an existing energy score.

    Only fields explicitly set on ``update`` are applied; ``updated_at`` is
    always set to the current UTC time.

    Args:
        db: Database session
        energy_score_id: ID of the energy score
        update: Updated energy score data

    Returns:
        Updated EnergyScore object or None

    Raises:
        Exception: Any error raised while committing; the session is rolled
            back before the error propagates
    """
    energy_score = get_energy_score(db, energy_score_id)
    if not energy_score:
        return None

    # Update fields
    update_data = update.model_dump(exclude_unset=True)
    for key, value in update_data.items():
        setattr(energy_score, key, value)

    energy_score.updated_at = _utcnow()

    # Save to database
    _commit_or_rollback(db)
    db.refresh(energy_score)

    return energy_score


def delete_energy_score(
    db: Session,
    energy_score_id: int
) -> bool:
    """
    Delete an energy score.

    Args:
        db: Database session
        energy_score_id: ID of the energy score

    Returns:
        True if deleted, False if not found

    Raises:
        Exception: Any error raised while committing; the session is rolled
            back before the error propagates
    """
    energy_score = get_energy_score(db, energy_score_id)
    if not energy_score:
        return False

    db.delete(energy_score)
    _commit_or_rollback(db)

    return True


def list_energy_scores(
    db: Session,
    match_id: Optional[int] = None,
    team_id: Optional[int] = None,
    min_score: Optional[float] = None,
    max_score: Optional[float] = None,
    min_confidence: Optional[float] = None,
    limit: int = 10,
    offset: int = 0
) -> List[EnergyScore]:
    """
    List energy scores with optional filters.

    Results are ordered newest first (by ``created_at``, then ``id``).

    Args:
        db: Database session
        match_id: Optional filter by match ID
        team_id: Optional filter by team ID
        min_score: Optional filter by minimum score
        max_score: Optional filter by maximum score
        min_confidence: Optional filter by minimum confidence
        limit: Maximum number of results
        offset: Offset for pagination

    Returns:
        List of EnergyScore objects
    """
    query = db.query(EnergyScore)

    # Apply filters
    if match_id is not None:
        query = query.filter(EnergyScore.match_id == match_id)

    if team_id is not None:
        query = query.filter(EnergyScore.team_id == team_id)

    if min_score is not None:
        query = query.filter(EnergyScore.score >= min_score)

    if max_score is not None:
        query = query.filter(EnergyScore.score <= max_score)

    if min_confidence is not None:
        query = query.filter(EnergyScore.confidence >= min_confidence)

    # Apply pagination and ordering. The id tiebreaker gives a total order,
    # so rows with equal created_at cannot repeat or vanish between pages.
    query = query.order_by(EnergyScore.created_at.desc(), EnergyScore.id.desc())
    query = query.offset(offset).limit(limit)

    return query.all()