279 lines
7.2 KiB
Python
279 lines
7.2 KiB
Python
"""
|
|
Energy Score Service.
|
|
|
|
This module provides business logic for energy score calculation and storage.
|
|
"""
|
|
|
|
from datetime import datetime
|
|
from typing import List, Dict, Optional
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import and_
|
|
|
|
from app.ml.energy_calculator import (
|
|
calculate_energy_score,
|
|
adjust_weights_for_degraded_mode,
|
|
get_source_weights
|
|
)
|
|
from app.models.energy_score import EnergyScore
|
|
from app.schemas.energy_score import (
|
|
EnergyScoreCreate,
|
|
EnergyScoreUpdate,
|
|
EnergyScoreCalculationRequest,
|
|
EnergyScoreCalculationResponse
|
|
)
|
|
from app.database import get_db
|
|
|
|
|
|
def calculate_and_store_energy_score(
    db: Session,
    request: EnergyScoreCalculationRequest
) -> EnergyScore:
    """
    Calculate an energy score and store it in the database.

    The aggregate score, confidence and list of sources used come from the
    ML module's ``calculate_energy_score``. Per-source component scores are
    computed with ``_calculate_component_score``. The stored per-source
    weights are those returned by ``adjust_weights_for_degraded_mode`` for
    the sources that were actually used.

    Args:
        db: Database session
        request: Energy score calculation request. Sentiment lists that are
            None are passed to the calculator as empty lists.

    Returns:
        The created EnergyScore object, refreshed after commit.

    Raises:
        Re-raises any exception from ``db.commit()`` after rolling the
        session back.
    """
    # Calculate energy score using the ML module
    result = calculate_energy_score(
        match_id=request.match_id,
        team_id=request.team_id,
        twitter_sentiments=request.twitter_sentiments or [],
        reddit_sentiments=request.reddit_sentiments or [],
        rss_sentiments=request.rss_sentiments or [],
        tweets_with_timestamps=request.tweets_with_timestamps or []
    )

    # Record the weights adjusted for the sources that were actually
    # available, so degraded-mode calculations can be traced later.
    available_sources = result['sources_used']
    adjusted_weights = adjust_weights_for_degraded_mode(
        original_weights=get_source_weights(),
        available_sources=available_sources
    )

    # A single timestamp guarantees created_at == updated_at on a new row;
    # calling utcnow() twice would give two slightly different values.
    now = datetime.utcnow()

    energy_score = EnergyScore(
        match_id=request.match_id,
        team_id=request.team_id,
        score=result['score'],
        confidence=result['confidence'],
        sources_used=available_sources,
        twitter_score=_calculate_component_score(request.twitter_sentiments),
        reddit_score=_calculate_component_score(request.reddit_sentiments),
        rss_score=_calculate_component_score(request.rss_sentiments),
        temporal_factor=result.get('temporal_factor'),
        twitter_weight=adjusted_weights.get('twitter'),
        reddit_weight=adjusted_weights.get('reddit'),
        rss_weight=adjusted_weights.get('rss'),
        created_at=now,
        updated_at=now
    )

    # Save to database
    db.add(energy_score)
    try:
        db.commit()
    except Exception:
        # A failed commit leaves the session in a state that requires an
        # explicit rollback before it can be used again.
        db.rollback()
        raise
    db.refresh(energy_score)

    return energy_score
|
|
|
|
|
|
def _calculate_component_score(sentiments: Optional[List[Dict]]) -> Optional[float]:
|
|
"""
|
|
Calculate component score for a single source.
|
|
|
|
Args:
|
|
sentiments: List of sentiment scores
|
|
|
|
Returns:
|
|
Component score or None if no sentiments
|
|
"""
|
|
if not sentiments:
|
|
return None
|
|
|
|
# Simple average of compound scores
|
|
total = sum(s.get('compound', 0) for s in sentiments)
|
|
return total / len(sentiments) if sentiments else None
|
|
|
|
|
|
def get_energy_score(
    db: Session,
    energy_score_id: int
) -> Optional[EnergyScore]:
    """
    Look up a single energy score by its ``id``.

    Args:
        db: Database session
        energy_score_id: ID of the energy score

    Returns:
        The first EnergyScore whose ``id`` matches, or None if none does.
    """
    matches = db.query(EnergyScore).filter_by(id=energy_score_id)
    return matches.first()
|
|
|
|
|
|
def get_energy_scores_by_match(
    db: Session,
    match_id: int
) -> List[EnergyScore]:
    """
    Fetch every energy score recorded for one match.

    Args:
        db: Database session
        match_id: ID of the match

    Returns:
        All EnergyScore rows with the given ``match_id``, in no guaranteed
        order. The list may be empty.
    """
    same_match = EnergyScore.match_id == match_id
    rows = db.query(EnergyScore).filter(same_match)
    return rows.all()
|
|
|
|
|
|
def get_energy_scores_by_team(
    db: Session,
    team_id: int
) -> List[EnergyScore]:
    """
    Fetch every energy score recorded for one team.

    Args:
        db: Database session
        team_id: ID of the team

    Returns:
        All EnergyScore rows with the given ``team_id``, in no guaranteed
        order. The list may be empty.
    """
    same_team = EnergyScore.team_id == team_id
    rows = db.query(EnergyScore).filter(same_team)
    return rows.all()
|
|
|
|
|
|
def get_energy_score_by_match_and_team(
    db: Session,
    match_id: int,
    team_id: int
) -> Optional[EnergyScore]:
    """
    Get the most recent energy score for a specific match and team.

    "Most recent" means the greatest ``created_at``. Ties are broken by the
    greatest ``id``, so the result is deterministic.

    Args:
        db: Database session
        match_id: ID of the match
        team_id: ID of the team

    Returns:
        EnergyScore object or None if no score exists for the pair.
    """
    # Without a tie-breaker, rows sharing the same created_at are returned in
    # arbitrary order. With auto-incrementing ids, the higher id is presumably
    # the later insert.
    newest_first = (EnergyScore.created_at.desc(), EnergyScore.id.desc())
    query = (
        db.query(EnergyScore)
        .filter(and_(EnergyScore.match_id == match_id, EnergyScore.team_id == team_id))
        .order_by(*newest_first)
    )
    return query.first()
|
|
|
|
|
|
def update_energy_score(
    db: Session,
    energy_score_id: int,
    update: EnergyScoreUpdate
) -> Optional[EnergyScore]:
    """
    Apply a partial update to an existing energy score.

    Only fields explicitly set on ``update`` (``model_dump(exclude_unset=True)``)
    are copied onto the row, and ``updated_at`` is refreshed.

    Args:
        db: Database session
        energy_score_id: ID of the energy score
        update: Updated energy score data

    Returns:
        The updated EnergyScore, or None if no row has that ID.
    """
    energy_score = get_energy_score(db, energy_score_id)
    if not energy_score:
        return None

    changes = update.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(energy_score, field_name, changes[field_name])
    energy_score.updated_at = datetime.utcnow()

    # Persist and reload server-side state.
    db.commit()
    db.refresh(energy_score)
    return energy_score
|
|
|
|
|
|
def delete_energy_score(
    db: Session,
    energy_score_id: int
) -> bool:
    """
    Remove an energy score and commit the deletion.

    Args:
        db: Database session
        energy_score_id: ID of the energy score

    Returns:
        True if a row was found and deleted, False if none had that ID.
    """
    target = get_energy_score(db, energy_score_id)
    if target:
        db.delete(target)
        db.commit()
        return True
    return False
|
|
|
|
|
|
def list_energy_scores(
    db: Session,
    match_id: Optional[int] = None,
    team_id: Optional[int] = None,
    min_score: Optional[float] = None,
    max_score: Optional[float] = None,
    min_confidence: Optional[float] = None,
    limit: int = 10,
    offset: int = 0
) -> List[EnergyScore]:
    """
    List energy scores with optional filters, newest first.

    Args:
        db: Database session
        match_id: Optional filter by match ID
        team_id: Optional filter by team ID
        min_score: Optional inclusive lower bound on score
        max_score: Optional inclusive upper bound on score
        min_confidence: Optional inclusive lower bound on confidence
        limit: Maximum number of results
        offset: Number of rows to skip (for pagination)

    Returns:
        List of EnergyScore objects ordered by ``created_at`` descending,
        then ``id`` descending.
    """
    query = db.query(EnergyScore)

    # Apply filters; a None argument means "do not filter on this field".
    if match_id is not None:
        query = query.filter(EnergyScore.match_id == match_id)
    if team_id is not None:
        query = query.filter(EnergyScore.team_id == team_id)
    if min_score is not None:
        query = query.filter(EnergyScore.score >= min_score)
    if max_score is not None:
        query = query.filter(EnergyScore.score <= max_score)
    if min_confidence is not None:
        query = query.filter(EnergyScore.confidence >= min_confidence)

    # created_at is not unique. Without the id tie-breaker the database may
    # order equal timestamps differently between requests, so offset/limit
    # pages could repeat or skip rows.
    query = query.order_by(EnergyScore.created_at.desc(), EnergyScore.id.desc())
    query = query.offset(offset).limit(limit)

    return query.all()
|