Initial commit

This commit is contained in:
2026-02-01 09:31:38 +01:00
commit e02db93960
4396 changed files with 1511612 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""Tests package for backend."""

184
backend/tests/conftest.py Normal file
View File

@@ -0,0 +1,184 @@
"""
Pytest configuration and fixtures for backend tests.
"""
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from app.database import Base, get_db
from app.models import (
User,
Tweet,
RedditPost,
RedditComment,
SentimentScore,
EnergyScore
)
# Create in-memory SQLite database for testing
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    # SQLite connections are bound to their creating thread by default;
    # TestClient may drive requests from worker threads, so disable the check.
    # NOTE(review): with an in-memory URL each pooled connection can get its
    # own database; if tables ever appear "missing", consider
    # poolclass=StaticPool — confirm against SQLAlchemy docs.
    connect_args={"check_same_thread": False}
)
# Session factory for tests: explicit commit/flush control, bound to the test engine.
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture(scope="function")
def db_session():
    """
    Yield a SQLAlchemy session backed by a freshly created schema.

    The full schema is created before the test and dropped afterwards,
    so every test observes an empty in-memory database.
    """
    Base.metadata.create_all(bind=engine)
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        # Always close the session, then wipe the schema for the next test.
        session.close()
        Base.metadata.drop_all(bind=engine)
@pytest.fixture(scope="function")
def client(db_session):
    """
    Yield a FastAPI TestClient wired to the per-test database session.

    The app's get_db dependency is overridden for the duration of the
    test and the override is removed once the client is torn down.
    """
    from fastapi.testclient import TestClient
    from app.main import app

    def _use_test_db():
        # Hand the already-open test session to request handlers.
        yield db_session

    app.dependency_overrides[get_db] = _use_test_db
    with TestClient(app) as test_client:
        yield test_client
    app.dependency_overrides.clear()
@pytest.fixture(scope="function")
def sample_user(db_session):
    """
    Persist and return a non-premium test user.

    Returns:
        User object committed to the test database.
    """
    record = User(
        username="testuser",
        email="test@example.com",
        hashed_password="hashed_password",
        is_premium=False,
    )
    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
@pytest.fixture(scope="function")
def sample_tweet(db_session, sample_user):
    """
    Persist and return a tweet owned by the sample user.

    Returns:
        Tweet object committed to the test database.
    """
    from datetime import datetime
    record = Tweet(
        tweet_id="1234567890",
        text="This is a test tweet",
        author="test_author",
        created_at=datetime.utcnow(),
        user_id=sample_user.id,
        retweets=10,
        likes=20,
        replies=5,
    )
    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
@pytest.fixture(scope="function")
def sample_reddit_post(db_session, sample_user):
    """
    Persist and return a Reddit post owned by the sample user.

    Returns:
        RedditPost object committed to the test database.
    """
    from datetime import datetime
    record = RedditPost(
        post_id="reddit123",
        title="Test Reddit Post",
        text="This is a test Reddit post",
        author="reddit_author",
        subreddit="test_subreddit",
        created_at=datetime.utcnow(),
        user_id=sample_user.id,
        upvotes=15,
        comments=3,
    )
    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
@pytest.fixture(scope="function")
def sample_sentiment_scores(db_session, sample_tweet, sample_reddit_post):
    """
    Persist one sentiment score for the sample tweet and one for the
    sample Reddit post.

    Returns:
        List of the two committed SentimentScore objects
        (Twitter entry first, Reddit entry second).
    """
    from datetime import datetime
    scores = [
        SentimentScore(
            entity_id=sample_tweet.tweet_id,
            entity_type="tweet",
            score=0.5,
            sentiment_type="positive",
            positive=0.6,
            negative=0.2,
            neutral=0.2,
            created_at=datetime.utcnow(),
        ),
        SentimentScore(
            entity_id=sample_reddit_post.post_id,
            entity_type="reddit_post",
            score=0.3,
            sentiment_type="positive",
            positive=0.4,
            negative=0.3,
            neutral=0.3,
            created_at=datetime.utcnow(),
        ),
    ]
    db_session.add_all(scores)
    db_session.commit()
    for score in scores:
        db_session.refresh(score)
    return scores

View File

@@ -0,0 +1,72 @@
"""
Script pour exécuter tous les tests du backend.
"""
import subprocess
import sys
def run_command(cmd: str, description: str) -> bool:
    """
    Run a shell command and report whether it succeeded.

    Args:
        cmd: Shell command line to execute.
        description: Human-readable label printed around the run.

    Returns:
        True when the command exited with status 0, False otherwise.
    """
    print(f"\n🔍 {description}")
    print("=" * 60)
    try:
        subprocess.run(
            cmd,
            shell=True,
            check=True,
            capture_output=False
        )
    except subprocess.CalledProcessError as e:
        print(f"{description} échoué avec le code {e.returncode}")
        return False
    print(f"{description} réussi !")
    return True
def main():
    """
    Run the full backend validation suite.

    Returns:
        0 when every step succeeded, 1 otherwise (suitable for sys.exit).
    """
    print("🧪 Exécution des tests du backend...")
    print("=" * 60)
    # (command, label) pairs, executed in order; every step always runs.
    steps = [
        ("pytest tests/ -v --tb=short", "Tests unitaires"),
        ("flake8 app/ --max-line-length=100", "Linting avec flake8"),
        ("black --check app/", "Vérification du formatage avec black"),
    ]
    all_passed = True
    for cmd, label in steps:
        all_passed &= run_command(cmd, label)
    print("\n" + "=" * 60)
    if all_passed:
        print("✅ Tous les tests et validations ont réussi !")
        return 0
    print("❌ Certains tests ou validations ont échoué.")
    return 1
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,41 @@
"""
Tests for API dependencies.
This module tests API key authentication dependencies.
"""
import pytest
from fastapi.testclient import TestClient
from app.main import app
# Module-level client with no dependency overrides: requests exercise the
# app's real wiring (including its real get_db dependency).
client = TestClient(app)
class TestApiKeyDependency:
    """Tests for API key authentication dependency."""
    def test_missing_api_key_header(self):
        """Test that a request without the X-API-Key header is rejected."""
        response = client.get("/api/v1/users/profile")
        # Previously this test made the request but asserted nothing, so it
        # could never fail. Mirror the tolerant check used below: 401 when the
        # dependency rejects the request, 404 if the endpoint doesn't exist
        # yet, 422 if FastAPI itself reports the missing required header.
        assert response.status_code in (401, 404, 422)
    def test_invalid_api_key(self):
        """Test that invalid API key returns 401."""
        response = client.get(
            "/api/v1/users/profile",
            headers={"X-API-Key": "invalid_key_12345"}
        )
        # Should return 401 with proper error format
        assert response.status_code in [401, 404]  # 404 if endpoint doesn't exist yet
    def test_api_key_includes_user_id(self):
        """Test that valid API key includes user_id in dependency."""
        # This test will be implemented once protected endpoints exist
        pass

View File

@@ -0,0 +1,203 @@
"""
Tests for API Key Service.
This module tests the API key generation, validation, and management.
"""
import pytest
from sqlalchemy.orm import Session
from app.services.apiKeyService import ApiKeyService
from app.models.api_key import ApiKey
from app.models.user import User
@pytest.fixture
def sample_user(db: Session):
    """Persist and return a basic non-premium user for API-key tests."""
    record = User(
        email="test@example.com",
        name="Test User",
        is_premium=False,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
class TestApiKeyGeneration:
    """Tests for API key generation."""
    def test_generate_api_key_success(self, db: Session, sample_user: User):
        """Test successful API key generation."""
        service = ApiKeyService(db)
        result = service.generate_api_key(sample_user.id, rate_limit=100)
        assert "api_key" in result
        assert "id" in result
        assert "key_prefix" in result
        assert "user_id" in result
        assert result["user_id"] == sample_user.id
        assert result["rate_limit"] == 100
        # Truthiness instead of `== True` (flake8 E712); also tolerant of
        # backends that surface booleans as 0/1.
        assert result["is_active"]
        assert len(result["api_key"]) > 0
        assert len(result["key_prefix"]) == 8
    def test_generate_api_key_invalid_user(self, db: Session):
        """Test API key generation with invalid user ID."""
        service = ApiKeyService(db)
        with pytest.raises(ValueError) as exc_info:
            service.generate_api_key(99999)
        assert "not found" in str(exc_info.value).lower()
    def test_generate_api_key_stores_hash(self, db: Session, sample_user: User):
        """Test that API key stores hash instead of plain key."""
        service = ApiKeyService(db)
        result = service.generate_api_key(sample_user.id)
        # Query the database to verify hash is stored
        api_key_record = db.query(ApiKey).filter(
            ApiKey.id == result["id"]
        ).first()
        assert api_key_record is not None
        assert api_key_record.key_hash != result["api_key"]
        assert len(api_key_record.key_hash) == 64  # SHA-256 hex digest length
        assert api_key_record.key_prefix == result["api_key"][:8]
class TestApiKeyValidation:
    """Tests for API key validation."""
    def test_validate_api_key_success(self, db: Session, sample_user: User):
        """Test successful API key validation."""
        service = ApiKeyService(db)
        result = service.generate_api_key(sample_user.id)
        plain_api_key = result["api_key"]
        # Validate the key
        validated = service.validate_api_key(plain_api_key)
        assert validated is not None
        assert validated.user_id == sample_user.id
        # Truthiness instead of `== True` (flake8 E712).
        assert validated.is_active
    def test_validate_api_key_invalid(self, db: Session):
        """Test validation with invalid API key."""
        service = ApiKeyService(db)
        validated = service.validate_api_key("invalid_key_12345")
        assert validated is None
    def test_validate_api_key_updates_last_used(self, db: Session, sample_user: User):
        """Test that validation updates last_used_at timestamp."""
        service = ApiKeyService(db)
        result = service.generate_api_key(sample_user.id)
        plain_api_key = result["api_key"]
        # Validate the key
        validated = service.validate_api_key(plain_api_key)
        assert validated.last_used_at is not None
    def test_validate_api_key_inactive(self, db: Session, sample_user: User):
        """Test validation with inactive API key."""
        service = ApiKeyService(db)
        result = service.generate_api_key(sample_user.id)
        plain_api_key = result["api_key"]
        # Deactivate the key directly in the database
        api_key_record = db.query(ApiKey).filter(
            ApiKey.id == result["id"]
        ).first()
        api_key_record.is_active = False
        db.commit()
        # An inactive key must no longer validate
        validated = service.validate_api_key(plain_api_key)
        assert validated is None
class TestApiKeyManagement:
    """Tests for API key management."""
    def test_get_user_api_keys(self, db: Session, sample_user: User):
        """Test retrieving all API keys for a user."""
        service = ApiKeyService(db)
        # Generate 3 keys; the returned dicts themselves are not needed
        # afterwards (previously collected into an unused list).
        for _ in range(3):
            service.generate_api_key(sample_user.id)
        # Retrieve all keys
        user_keys = service.get_user_api_keys(sample_user.id)
        assert len(user_keys) == 3
        # Verify plain keys are not included
        for key in user_keys:
            assert "api_key" not in key
            assert "key_prefix" in key
    def test_revoke_api_key_success(self, db: Session, sample_user: User):
        """Test successful API key revocation."""
        service = ApiKeyService(db)
        result = service.generate_api_key(sample_user.id)
        api_key_id = result["id"]
        # Revoke the key
        revoked = service.revoke_api_key(api_key_id, sample_user.id)
        assert revoked is True
        # Verify it's deactivated (`not x` instead of `== False`, flake8 E712)
        api_key_record = db.query(ApiKey).filter(
            ApiKey.id == api_key_id
        ).first()
        assert not api_key_record.is_active
    def test_revoke_api_key_wrong_user(self, db: Session, sample_user: User):
        """Test revoking key with wrong user ID."""
        service = ApiKeyService(db)
        result = service.generate_api_key(sample_user.id)
        api_key_id = result["id"]
        # Try to revoke with wrong user ID
        revoked = service.revoke_api_key(api_key_id, 99999)
        assert revoked is False
    def test_regenerate_api_key_success(self, db: Session, sample_user: User):
        """Test successful API key regeneration."""
        service = ApiKeyService(db)
        old_result = service.generate_api_key(sample_user.id)
        old_api_key_id = old_result["id"]
        # Regenerate
        new_result = service.regenerate_api_key(old_api_key_id, sample_user.id)
        assert new_result is not None
        assert "api_key" in new_result
        assert new_result["api_key"] != old_result["api_key"]
        assert new_result["id"] != old_api_key_id
        # Verify old key is deactivated
        old_key_record = db.query(ApiKey).filter(
            ApiKey.id == old_api_key_id
        ).first()
        assert not old_key_record.is_active
    def test_regenerate_api_key_wrong_user(self, db: Session, sample_user: User):
        """Test regenerating key with wrong user ID."""
        service = ApiKeyService(db)
        old_result = service.generate_api_key(sample_user.id)
        # Try to regenerate with wrong user ID
        new_result = service.regenerate_api_key(old_result["id"], 99999)
        assert new_result is None

View File

@@ -0,0 +1,515 @@
"""
Tests for Backtesting Module.
This module contains unit tests for the backtesting functionality
including accuracy calculation, comparison logic, and export formats.
"""
import pytest
from datetime import datetime
from app.ml.backtesting import (
validate_accuracy,
compare_prediction,
run_backtesting_single_match,
run_backtesting_batch,
export_to_json,
export_to_csv,
export_to_html,
filter_matches_by_league,
filter_matches_by_period,
ACCURACY_VALIDATED_THRESHOLD,
ACCURACY_ALERT_THRESHOLD
)
class TestValidateAccuracy:
    """Tests for validate_accuracy function."""
    def test_validate_accuracy_above_threshold(self):
        """Accuracy strictly above 60% is validated."""
        assert validate_accuracy(65.0) == 'VALIDATED'
    def test_validate_accuracy_at_threshold(self):
        """Accuracy exactly at 60% is validated (inclusive threshold)."""
        assert validate_accuracy(60.0) == 'VALIDATED'
    def test_validate_accuracy_below_target(self):
        """Accuracy in [55%, 60%) is flagged as below target."""
        assert validate_accuracy(58.0) == 'BELOW_TARGET'
    def test_validate_accuracy_alert(self):
        """Accuracy under 55% requires revision."""
        assert validate_accuracy(50.0) == 'REVISION_REQUIRED'
    def test_validate_accuracy_boundary(self):
        """Accuracy exactly at 55% counts as below target, not revision."""
        assert validate_accuracy(55.0) == 'BELOW_TARGET'
    def test_validate_accuracy_extreme_high(self):
        """Perfect accuracy is validated."""
        assert validate_accuracy(100.0) == 'VALIDATED'
    def test_validate_accuracy_zero(self):
        """Zero accuracy requires revision."""
        assert validate_accuracy(0.0) == 'REVISION_REQUIRED'
    def test_validate_thresholds_constants(self):
        """The module-level threshold constants hold the expected values."""
        assert ACCURACY_VALIDATED_THRESHOLD == 60.0
        assert ACCURACY_ALERT_THRESHOLD == 55.0
class TestComparePrediction:
    """Tests for compare_prediction function."""
    def test_compare_home_correct(self):
        """Matching 'home' prediction and outcome compare as correct."""
        assert compare_prediction('home', 'home') is True
    def test_compare_away_correct(self):
        """Matching 'away' prediction and outcome compare as correct."""
        assert compare_prediction('away', 'away') is True
    def test_compare_draw_correct(self):
        """Matching 'draw' prediction and outcome compare as correct."""
        assert compare_prediction('draw', 'draw') is True
    def test_compare_home_incorrect(self):
        """Mismatched prediction and outcome compare as incorrect."""
        assert compare_prediction('home', 'away') is False
    def test_compare_case_insensitive(self):
        """Comparison ignores letter case on either side."""
        for predicted, actual in [('HOME', 'home'), ('Home', 'home'), ('home', 'HOME')]:
            assert compare_prediction(predicted, actual) is True
class TestRunBacktestingSingleMatch:
    """Tests for run_backtesting_single_match function."""
    def test_single_match_correct_prediction(self):
        """A home win predicted from higher home energy is marked correct."""
        outcome = run_backtesting_single_match(
            match_id=1,
            home_team='PSG',
            away_team='OM',
            home_energy=65.0,
            away_energy=45.0,
            actual_winner='home'
        )
        # Input fields are echoed back in the result
        assert outcome['match_id'] == 1
        assert outcome['home_team'] == 'PSG'
        assert outcome['away_team'] == 'OM'
        assert outcome['home_energy'] == 65.0
        assert outcome['away_energy'] == 45.0
        assert outcome['actual_winner'] == 'home'
        # Prediction matches the actual winner
        assert outcome['correct'] is True
        assert 'prediction' in outcome
        assert outcome['prediction']['predicted_winner'] == 'home'
    def test_single_match_incorrect_prediction(self):
        """An away win despite higher home energy is marked incorrect."""
        outcome = run_backtesting_single_match(
            match_id=2,
            home_team='PSG',
            away_team='OM',
            home_energy=65.0,
            away_energy=45.0,
            actual_winner='away'
        )
        assert outcome['match_id'] == 2
        assert outcome['correct'] is False
        assert outcome['prediction']['predicted_winner'] != 'away'
    def test_single_match_with_equal_energy(self):
        """Equal energy scores predict a draw."""
        outcome = run_backtesting_single_match(
            match_id=3,
            home_team='PSG',
            away_team='OM',
            home_energy=50.0,
            away_energy=50.0,
            actual_winner='draw'
        )
        assert outcome['prediction']['predicted_winner'] == 'draw'
        assert outcome['correct'] is True
class TestRunBacktestingBatch:
    """Tests for run_backtesting_batch function."""
    def test_batch_all_correct(self):
        """Test backtesting with all correct predictions."""
        matches = [
            {
                'match_id': 1,
                'home_team': 'PSG',
                'away_team': 'OM',
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'home',
                'league': 'Ligue 1'
            },
            {
                'match_id': 2,
                'home_team': 'Lyon',
                'away_team': 'Monaco',
                'home_energy': 45.0,
                'away_energy': 65.0,
                'actual_winner': 'away',
                'league': 'Ligue 1'
            }
        ]
        result = run_backtesting_batch(matches)
        assert result['total_matches'] == 2
        assert result['correct_predictions'] == 2
        assert result['incorrect_predictions'] == 0
        assert result['accuracy'] == 100.0
        assert result['status'] == 'VALIDATED'
        assert len(result['results']) == 2
    def test_batch_mixed_results(self):
        """Test backtesting with mixed correct/incorrect predictions."""
        matches = [
            {
                'match_id': 1,
                'home_team': 'PSG',
                'away_team': 'OM',
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'home',
                'league': 'Ligue 1'
            },
            {
                'match_id': 2,
                'home_team': 'Lyon',
                'away_team': 'Monaco',
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'away',
                'league': 'Ligue 1'
            }
        ]
        result = run_backtesting_batch(matches)
        assert result['total_matches'] == 2
        assert result['correct_predictions'] == 1
        assert result['incorrect_predictions'] == 1
        # 50% accuracy is under the 55% alert threshold
        assert result['accuracy'] == 50.0
        assert result['status'] == 'REVISION_REQUIRED'
    def test_batch_with_leagues(self):
        """Test backtesting with multiple leagues."""
        matches = [
            {
                'match_id': 1,
                'home_team': 'PSG',
                'away_team': 'OM',
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'home',
                'league': 'Ligue 1'
            },
            {
                'match_id': 2,
                'home_team': 'Man City',
                'away_team': 'Liverpool',
                'home_energy': 70.0,
                'away_energy': 50.0,
                'actual_winner': 'home',
                'league': 'Premier League'
            }
        ]
        result = run_backtesting_batch(matches)
        # Per-league breakdown: one match counted under each league
        assert 'metrics_by_league' in result
        assert 'Ligue 1' in result['metrics_by_league']
        assert 'Premier League' in result['metrics_by_league']
        assert result['metrics_by_league']['Ligue 1']['total'] == 1
        assert result['metrics_by_league']['Premier League']['total'] == 1
    def test_batch_empty(self):
        """Test backtesting with no matches."""
        result = run_backtesting_batch([])
        assert result['total_matches'] == 0
        assert result['correct_predictions'] == 0
        assert result['incorrect_predictions'] == 0
        assert result['accuracy'] == 0.0
    def test_batch_missing_required_field(self):
        """Test backtesting with missing required field raises error."""
        matches = [
            {
                'match_id': 1,
                'home_team': 'PSG',
                # Missing 'away_team'
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'home'
            }
        ]
        with pytest.raises(ValueError, match="missing required fields"):
            run_backtesting_batch(matches)
    def test_batch_with_dates(self):
        """Test backtesting with match dates."""
        match_date = datetime(2025, 1, 15, 20, 0, 0)
        matches = [
            {
                'match_id': 1,
                'home_team': 'PSG',
                'away_team': 'OM',
                'home_energy': 65.0,
                'away_energy': 45.0,
                'actual_winner': 'home',
                'league': 'Ligue 1',
                'date': match_date
            }
        ]
        result = run_backtesting_batch(matches)
        # Dates are serialized to ISO-8601 strings in the per-match results
        assert result['results'][0]['date'] == match_date.isoformat()
class TestExportFormats:
    """Tests for export functions."""
    def test_export_to_json(self):
        """Test JSON export format."""
        # Aggregate-only report (no per-match rows) is enough to check
        # that the top-level fields appear in the JSON string.
        backtesting_result = {
            'total_matches': 2,
            'correct_predictions': 1,
            'incorrect_predictions': 1,
            'accuracy': 50.0,
            'status': 'REVISION_REQUIRED',
            'results': [],
            'metrics_by_league': {},
            'timestamp': '2026-01-17T10:00:00Z',
            'validation_thresholds': {'validated': 60.0, 'alert': 55.0}
        }
        json_output = export_to_json(backtesting_result)
        assert isinstance(json_output, str)
        assert 'total_matches' in json_output
        assert 'accuracy' in json_output
    def test_export_to_csv(self):
        """Test CSV export format."""
        # One fully-populated match row so both the header and the row
        # values can be checked in the CSV output.
        backtesting_result = {
            'total_matches': 1,
            'correct_predictions': 1,
            'incorrect_predictions': 0,
            'accuracy': 100.0,
            'status': 'VALIDATED',
            'results': [
                {
                    'match_id': 1,
                    'league': 'Ligue 1',
                    'date': '2026-01-15T20:00:00Z',
                    'home_team': 'PSG',
                    'away_team': 'OM',
                    'home_energy': 65.0,
                    'away_energy': 45.0,
                    'prediction': {'predicted_winner': 'home', 'confidence': 40.0},
                    'actual_winner': 'home',
                    'correct': True
                }
            ],
            'metrics_by_league': {},
            'timestamp': '2026-01-17T10:00:00Z',
            'validation_thresholds': {'validated': 60.0, 'alert': 55.0}
        }
        csv_output = export_to_csv(backtesting_result)
        assert isinstance(csv_output, str)
        assert 'match_id' in csv_output
        assert 'PSG' in csv_output
        assert 'OM' in csv_output
    def test_export_to_html(self):
        """Test HTML export format."""
        # Same single-match report; the HTML export should render a full
        # document with the headline accuracy and status.
        backtesting_result = {
            'total_matches': 1,
            'correct_predictions': 1,
            'incorrect_predictions': 0,
            'accuracy': 100.0,
            'status': 'VALIDATED',
            'results': [
                {
                    'match_id': 1,
                    'league': 'Ligue 1',
                    'date': '2026-01-15T20:00:00Z',
                    'home_team': 'PSG',
                    'away_team': 'OM',
                    'home_energy': 65.0,
                    'away_energy': 45.0,
                    'prediction': {'predicted_winner': 'home', 'confidence': 40.0},
                    'actual_winner': 'home',
                    'correct': True
                }
            ],
            'metrics_by_league': {},
            'timestamp': '2026-01-17T10:00:00Z',
            'validation_thresholds': {'validated': 60.0, 'alert': 55.0}
        }
        html_output = export_to_html(backtesting_result)
        assert isinstance(html_output, str)
        assert '<html>' in html_output
        assert '</html>' in html_output
        assert 'Backtesting Report' in html_output
        assert '100.0%' in html_output
        assert 'VALIDATED' in html_output
class TestFilterMatchesByLeague:
    """Tests for filter_matches_by_league function."""
    def test_filter_by_single_league(self):
        """Filtering by one league keeps only that league's matches."""
        fixtures = [
            {'league': 'Ligue 1', 'match_id': 1},
            {'league': 'Premier League', 'match_id': 2},
            {'league': 'Ligue 1', 'match_id': 3}
        ]
        kept = filter_matches_by_league(fixtures, ['Ligue 1'])
        assert len(kept) == 2
        assert all(entry['league'] == 'Ligue 1' for entry in kept)
    def test_filter_by_multiple_leagues(self):
        """Filtering by several leagues keeps matches from each, in order."""
        fixtures = [
            {'league': 'Ligue 1', 'match_id': 1},
            {'league': 'Premier League', 'match_id': 2},
            {'league': 'La Liga', 'match_id': 3}
        ]
        kept = filter_matches_by_league(fixtures, ['Ligue 1', 'Premier League'])
        assert len(kept) == 2
        assert kept[0]['league'] == 'Ligue 1'
        assert kept[1]['league'] == 'Premier League'
    def test_filter_no_leagues(self):
        """An empty league list acts as no filter at all."""
        fixtures = [
            {'league': 'Ligue 1', 'match_id': 1},
            {'league': 'Premier League', 'match_id': 2}
        ]
        assert len(filter_matches_by_league(fixtures, [])) == 2
    def test_filter_none_leagues(self):
        """Passing None for the league list also disables filtering."""
        fixtures = [
            {'league': 'Ligue 1', 'match_id': 1},
            {'league': 'Premier League', 'match_id': 2}
        ]
        assert len(filter_matches_by_league(fixtures, None)) == 2
class TestFilterMatchesByPeriod:
    """Tests for filter_matches_by_period function."""
    def test_filter_by_start_date(self):
        """Only matches on/after the start date survive the filter."""
        fixtures = [
            {'date': datetime(2025, 1, 10), 'match_id': 1},
            {'date': datetime(2025, 1, 20), 'match_id': 2},
            {'date': datetime(2025, 1, 5), 'match_id': 3}
        ]
        kept = filter_matches_by_period(fixtures, start_date=datetime(2025, 1, 15))
        assert len(kept) == 1
        assert kept[0]['match_id'] == 2
    def test_filter_by_end_date(self):
        """Only matches on/before the end date survive the filter."""
        fixtures = [
            {'date': datetime(2025, 1, 10), 'match_id': 1},
            {'date': datetime(2025, 1, 20), 'match_id': 2},
            {'date': datetime(2025, 1, 5), 'match_id': 3}
        ]
        kept = filter_matches_by_period(fixtures, end_date=datetime(2025, 1, 15))
        assert len(kept) == 2
        assert sorted(entry['match_id'] for entry in kept) == [1, 3]
    def test_filter_by_date_range(self):
        """Start and end dates together keep the inclusive window."""
        fixtures = [
            {'date': datetime(2025, 1, 10), 'match_id': 1},
            {'date': datetime(2025, 1, 20), 'match_id': 2},
            {'date': datetime(2025, 1, 15), 'match_id': 3},
            {'date': datetime(2025, 1, 5), 'match_id': 4}
        ]
        kept = filter_matches_by_period(
            fixtures,
            start_date=datetime(2025, 1, 10),
            end_date=datetime(2025, 1, 15)
        )
        assert len(kept) == 2
        assert sorted(entry['match_id'] for entry in kept) == [1, 3]
    def test_filter_no_dates(self):
        """With both bounds set to None, every match is kept."""
        fixtures = [
            {'date': datetime(2025, 1, 10), 'match_id': 1},
            {'date': datetime(2025, 1, 20), 'match_id': 2}
        ]
        kept = filter_matches_by_period(fixtures, start_date=None, end_date=None)
        assert len(kept) == 2
    def test_filter_no_date_field(self):
        """Entries lacking a 'date' key are dropped when a bound is active."""
        fixtures = [
            {'date': datetime(2025, 1, 10), 'match_id': 1},
            {'match_id': 2}  # No date field
        ]
        kept = filter_matches_by_period(fixtures, start_date=datetime(2025, 1, 1))
        assert len(kept) == 1
        assert kept[0]['match_id'] == 1

View File

@@ -0,0 +1,89 @@
"""Tests pour valider la structure du répertoire backend."""
import os
from pathlib import Path
def test_backend_directory_exists():
    """Check that the backend directory (two levels up) exists."""
    backend_dir = Path(__file__).parent.parent
    assert backend_dir.exists(), "Le répertoire backend doit exister"
    assert backend_dir.is_dir(), "Le backend doit être un répertoire"
def test_app_directory_exists():
    """Check that the app package directory exists."""
    app_dir = Path(__file__).parent.parent / "app"
    assert app_dir.exists(), "Le répertoire app doit exister"
    assert app_dir.is_dir(), "L'app doit être un répertoire"
def test_app_init_exists():
    """Check that app/__init__.py exists and is a regular file."""
    init_file = Path(__file__).parent.parent / "app" / "__init__.py"
    assert init_file.exists(), "Le fichier app/__init__.py doit exister"
    assert init_file.is_file(), "L'__init__.py doit être un fichier"
def test_models_directory_exists():
    """Check that app/models exists and is a directory."""
    models_dir = Path(__file__).parent.parent / "app" / "models"
    assert models_dir.exists(), "Le répertoire models doit exister"
    assert models_dir.is_dir(), "Les models doivent être un répertoire"
def test_schemas_directory_exists():
    """Check that app/schemas exists and is a directory."""
    schemas_dir = Path(__file__).parent.parent / "app" / "schemas"
    assert schemas_dir.exists(), "Le répertoire schemas doit exister"
    assert schemas_dir.is_dir(), "Les schemas doivent être un répertoire"
def test_api_directory_exists():
    """Check that app/api exists and is a directory."""
    api_dir = Path(__file__).parent.parent / "app" / "api"
    assert api_dir.exists(), "Le répertoire api doit exister"
    assert api_dir.is_dir(), "L'api doit être un répertoire"
def test_models_init_exists():
    """Check that app/models/__init__.py exists and is a regular file."""
    init_file = Path(__file__).parent.parent / "app" / "models" / "__init__.py"
    assert init_file.exists(), "Le fichier models/__init__.py doit exister"
    assert init_file.is_file(), "L'__init__.py des models doit être un fichier"
def test_schemas_init_exists():
    """Check that app/schemas/__init__.py exists and is a regular file."""
    init_file = Path(__file__).parent.parent / "app" / "schemas" / "__init__.py"
    assert init_file.exists(), "Le fichier schemas/__init__.py doit exister"
    assert init_file.is_file(), "L'__init__.py des schemas doit être un fichier"
def test_api_init_exists():
    """Check that app/api/__init__.py exists and is a regular file."""
    init_file = Path(__file__).parent.parent / "app" / "api" / "__init__.py"
    assert init_file.exists(), "Le fichier api/__init__.py doit exister"
    assert init_file.is_file(), "L'__init__.py de l'api doit être un fichier"
def test_main_py_exists():
    """Check that app/main.py exists and is a regular file."""
    main_file = Path(__file__).parent.parent / "app" / "main.py"
    assert main_file.exists(), "Le fichier app/main.py doit exister"
    assert main_file.is_file(), "Le main.py doit être un fichier"
def test_database_py_exists():
    """Check that app/database.py exists and is a regular file."""
    database_file = Path(__file__).parent.parent / "app" / "database.py"
    assert database_file.exists(), "Le fichier app/database.py doit exister"
    assert database_file.is_file(), "Le database.py doit être un fichier"
def test_fastapi_app_importable():
    """Check that the FastAPI application can be imported."""
    try:
        from app.main import app
    except ImportError as e:
        # Chain the original ImportError (`from e`) so the real cause stays
        # visible in the traceback instead of only its string form.
        raise AssertionError(f"Impossible d'importer l'application FastAPI: {e}") from e
    # Assert outside the try: an AssertionError here was never meant to be
    # caught by the ImportError handler anyway.
    assert app is not None, "L'application FastAPI doit être importable"

View File

@@ -0,0 +1,378 @@
"""
Tests for Energy Calculator Module.
This module tests the energy score calculation functionality including:
- Multi-source weighted calculation
- Temporal weighting
- Degraded mode (when sources are unavailable)
- Score normalization (0-100)
- Confidence calculation
"""
import pytest
from datetime import datetime, timedelta
from app.ml.energy_calculator import (
calculate_energy_score,
apply_temporal_weighting,
calculate_confidence,
normalize_score,
apply_source_weights,
adjust_weights_for_degraded_mode
)
class TestCalculateEnergyScore:
    """Test the main energy score calculation function."""
    def test_calculate_energy_score_with_all_sources(self):
        """All three sources present: three sources used, high confidence."""
        twitter = [
            {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
        ]
        reddit = [
            {'compound': 0.3, 'positive': 0.4, 'negative': 0.3, 'neutral': 0.3, 'sentiment': 'positive'}
        ]
        rss = [
            {'compound': 0.4, 'positive': 0.5, 'negative': 0.2, 'neutral': 0.3, 'sentiment': 'positive'}
        ]
        outcome = calculate_energy_score(
            match_id=1,
            team_id=1,
            twitter_sentiments=twitter,
            reddit_sentiments=reddit,
            rss_sentiments=rss
        )
        assert 'score' in outcome
        assert 'confidence' in outcome
        assert 'sources_used' in outcome
        assert 0 <= outcome['score'] <= 100
        assert len(outcome['sources_used']) == 3
        assert outcome['confidence'] > 0.6  # High confidence with all sources
    def test_calculate_energy_score_with_twitter_only(self):
        """Only Twitter available (degraded mode): lower confidence."""
        twitter = [
            {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
        ]
        outcome = calculate_energy_score(
            match_id=1,
            team_id=1,
            twitter_sentiments=twitter,
            reddit_sentiments=[],
            rss_sentiments=[]
        )
        assert 'score' in outcome
        assert 'confidence' in outcome
        assert 'sources_used' in outcome
        assert 0 <= outcome['score'] <= 100
        assert len(outcome['sources_used']) == 1
        assert 'twitter' in outcome['sources_used']
        assert outcome['confidence'] < 0.6  # Lower confidence in degraded mode
    def test_calculate_energy_score_no_sentiment_data(self):
        """No sentiment data at all: zero score, zero confidence, no sources."""
        outcome = calculate_energy_score(
            match_id=1,
            team_id=1,
            twitter_sentiments=[],
            reddit_sentiments=[],
            rss_sentiments=[]
        )
        assert outcome['score'] == 0.0
        assert outcome['confidence'] == 0.0
        assert len(outcome['sources_used']) == 0
class TestApplySourceWeights:
    """Test source weight application and adjustment."""
    def test_apply_source_weights_all_sources(self):
        """With every source available the 60/25/15 split applies."""
        combined = apply_source_weights(
            twitter_score=50.0,
            reddit_score=40.0,
            rss_score=30.0,
            available_sources=['twitter', 'reddit', 'rss']
        )
        assert combined == (50.0 * 0.60) + (40.0 * 0.25) + (30.0 * 0.15)
    def test_apply_source_weights_twitter_only(self):
        """With only Twitter available, its score carries 100% of the weight."""
        combined = apply_source_weights(
            twitter_score=50.0,
            reddit_score=0.0,
            rss_score=0.0,
            available_sources=['twitter']
        )
        assert combined == 50.0
    def test_adjust_weights_for_degraded_mode(self):
        """Missing sources are dropped and the rest renormalized to 1.0."""
        adjusted = adjust_weights_for_degraded_mode(
            original_weights={
                'twitter': 0.60,
                'reddit': 0.25,
                'rss': 0.15
            },
            available_sources=['twitter', 'reddit']
        )
        # Only the available sources remain
        assert len(adjusted) == 2
        assert 'twitter' in adjusted
        assert 'reddit' in adjusted
        assert 'rss' not in adjusted
        # Renormalized weights must still sum to 1.0 (float tolerance)
        assert abs(sum(adjusted.values()) - 1.0) < 0.001
class TestApplyTemporalWeighting:
    """Test temporal weighting of sentiment scores."""

    def test_temporal_weighting_recent_tweets(self):
        """Recent tweets (within 1 hour) keep the score in the valid range."""
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # switching to datetime.now(timezone.utc) depends on whether the
        # calculator expects naive timestamps — confirm before changing.
        now = datetime.utcnow()
        fresh_tweets = [
            {'compound': 0.5, 'created_at': now - timedelta(minutes=30)},
            {'compound': 0.6, 'created_at': now - timedelta(minutes=15)},
        ]
        weighted_score = apply_temporal_weighting(
            base_score=50.0,
            tweets_with_timestamps=fresh_tweets
        )
        # Recent tweets carry near-full weight; result stays normalized.
        assert 0 <= weighted_score <= 100

    def test_temporal_weighting_old_tweets(self):
        """Old tweets (24+ hours) are down-weighted but stay in range."""
        now = datetime.utcnow()
        stale_tweets = [
            {'compound': 0.5, 'created_at': now - timedelta(hours=30)},
            {'compound': 0.6, 'created_at': now - timedelta(hours=25)},
        ]
        weighted_score = apply_temporal_weighting(
            base_score=50.0,
            tweets_with_timestamps=stale_tweets
        )
        assert 0 <= weighted_score <= 100
class TestNormalizeScore:
    """Test score normalization to 0-100 range."""

    def test_normalize_score_negative(self):
        """Values below the range clamp to 0."""
        assert normalize_score(-1.5) == 0.0

    def test_normalize_score_positive(self):
        """Values above the range clamp to 100."""
        assert normalize_score(150.0) == 100.0

    def test_normalize_score_in_range(self):
        """In-range values pass through unchanged, including both bounds."""
        for value in (0.0, 50.0, 100.0):
            assert normalize_score(value) == value
class TestCalculateConfidence:
    """Test confidence level calculation."""

    def test_confidence_all_sources(self):
        """Full source coverage yields high confidence, capped at 1.0."""
        confidence = calculate_confidence(
            available_sources=['twitter', 'reddit', 'rss'],
            total_weight=1.0
        )
        assert 0.6 < confidence <= 1.0

    def test_confidence_single_source(self):
        """A lone source yields reduced confidence, floored at 0.3."""
        confidence = calculate_confidence(
            available_sources=['twitter'],
            total_weight=0.6
        )
        assert 0.3 <= confidence < 0.6

    def test_confidence_no_sources(self):
        """No sources at all means zero confidence."""
        confidence = calculate_confidence(
            available_sources=[],
            total_weight=0.0
        )
        assert confidence == 0.0
class TestEnergyFormula:
    """Test the core energy formula: Score = (Positive - Negative) × Volume × Virality."""

    @staticmethod
    def _energy(positive_ratio, negative_ratio, volume, virality):
        # The raw formula under test, factored out for readability.
        return (positive_ratio - negative_ratio) * volume * virality

    def test_formula_positive_sentiment(self):
        """A positive-leaning crowd yields positive energy."""
        assert self._energy(0.7, 0.2, 100, 1.5) > 0

    def test_formula_negative_sentiment(self):
        """A negative-leaning crowd yields negative energy."""
        assert self._energy(0.2, 0.7, 100, 1.5) < 0

    def test_formula_neutral_sentiment(self):
        """A perfectly split crowd yields exactly zero energy."""
        assert self._energy(0.5, 0.5, 100, 1.5) == 0.0

    def test_formula_zero_volume(self):
        """No volume means no energy regardless of sentiment balance."""
        assert self._energy(0.7, 0.2, 0, 1.5) == 0.0
if __name__ == '__main__':
    # Allow running this test module directly (outside the pytest CLI),
    # with verbose per-test output.
    pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,282 @@
"""
Manual test script for energy calculator.
This script runs basic tests to verify the energy calculator implementation.
"""
from datetime import datetime, timedelta
from app.ml.energy_calculator import (
calculate_energy_score,
apply_source_weights,
adjust_weights_for_degraded_mode,
apply_temporal_weighting,
normalize_score,
calculate_confidence
)
def test_apply_source_weights():
    """Test source weight application."""
    print("\n=== Test: Apply Source Weights ===")
    # All three sources present: expect the 60/25/15 weighted sum.
    expected = (50.0 * 0.60) + (40.0 * 0.25) + (30.0 * 0.15)
    weighted_score = apply_source_weights(
        twitter_score=50.0,
        reddit_score=40.0,
        rss_score=30.0,
        available_sources=['twitter', 'reddit', 'rss']
    )
    assert weighted_score == expected, f"Expected {expected}, got {weighted_score}"
    print(f"✓ All sources: {weighted_score} (expected: {expected})")
    # Twitter alone: its weight is renormalized to 100%.
    weighted_score = apply_source_weights(
        twitter_score=50.0,
        reddit_score=0.0,
        rss_score=0.0,
        available_sources=['twitter']
    )
    assert weighted_score == 50.0, f"Expected 50.0, got {weighted_score}"
    print(f"✓ Twitter only: {weighted_score}")
def test_adjust_weights_for_degraded_mode():
    """Test degraded mode weight adjustment."""
    print("\n=== Test: Adjust Weights for Degraded Mode ===")
    base_weights = {'twitter': 0.60, 'reddit': 0.25, 'rss': 0.15}
    # Two sources available: remaining weights renormalize to 1.0.
    adjusted = adjust_weights_for_degraded_mode(
        original_weights=base_weights,
        available_sources=['twitter', 'reddit']
    )
    total_weight = sum(adjusted.values())
    assert abs(total_weight - 1.0) < 0.001, f"Total weight should be 1.0, got {total_weight}"
    print(f"✓ Twitter+Reddit weights: {adjusted} (total: {total_weight})")
    # A single surviving source absorbs all of the weight.
    adjusted = adjust_weights_for_degraded_mode(
        original_weights=base_weights,
        available_sources=['twitter']
    )
    total_weight = sum(adjusted.values())
    assert adjusted['twitter'] == 1.0, f"Twitter weight should be 1.0, got {adjusted['twitter']}"
    print(f"✓ Twitter only: {adjusted}")
def test_normalize_score():
    """Test score normalization."""
    print("\n=== Test: Normalize Score ===")
    # (input, expected output, label) — below range clamps to 0, above
    # range clamps to 100, in-range values pass through unchanged.
    cases = [
        (-10.0, 0.0, "Negative score"),
        (150.0, 100.0, "Score above 100"),
        (50.0, 50.0, "Score in range"),
    ]
    for value, expected, label in cases:
        normalized = normalize_score(value)
        assert normalized == expected, f"Expected {expected}, got {normalized}"
        print(f"✓ {label}: {value} → {normalized}")
def test_calculate_confidence():
    """Test confidence calculation.

    Mirrors the pytest suite's expectations: high confidence with full
    source coverage, reduced confidence in degraded (single-source) mode,
    and zero confidence with no sources at all.
    """
    print("\n=== Test: Calculate Confidence ===")
    # Test with all sources
    confidence = calculate_confidence(
        available_sources=['twitter', 'reddit', 'rss'],
        total_weight=1.0
    )
    assert confidence > 0.6, f"Confidence should be > 0.6, got {confidence}"
    print(f"✓ All sources: {confidence}")
    # Test with single source (Twitter).
    # FIX: the unit tests (test_confidence_single_source and the degraded-mode
    # energy tests) assert confidence < 0.6 for this exact input; asserting
    # == 0.6 here contradicted them, so both suites could never pass together.
    confidence = calculate_confidence(
        available_sources=['twitter'],
        total_weight=0.6
    )
    assert confidence < 0.6, f"Confidence should be < 0.6, got {confidence}"
    print(f"✓ Twitter only: {confidence}")
    # Test with no sources
    confidence = calculate_confidence(
        available_sources=[],
        total_weight=0.0
    )
    assert confidence == 0.0, f"Expected 0.0, got {confidence}"
    print(f"✓ No sources: {confidence}")
def test_apply_temporal_weighting():
    """Test temporal weighting."""
    print("\n=== Test: Apply Temporal Weighting ===")
    base_score = 50.0
    now = datetime.utcnow()

    def make_tweets(deltas):
        # Pair the two compound scores used throughout with the given ages.
        return [
            {'compound': compound, 'created_at': now - delta}
            for compound, delta in zip((0.5, 0.6), deltas)
        ]

    # Recent tweets (within 1 hour) keep close to full weight.
    weighted_score = apply_temporal_weighting(
        base_score=base_score,
        tweets_with_timestamps=make_tweets((timedelta(minutes=30), timedelta(minutes=15)))
    )
    assert 0 <= weighted_score <= 100, f"Score should be between 0 and 100, got {weighted_score}"
    print(f"✓ Recent tweets: {base_score} → {weighted_score}")
    # Old tweets (24+ hours) are strongly down-weighted.
    weighted_score = apply_temporal_weighting(
        base_score=base_score,
        tweets_with_timestamps=make_tweets((timedelta(hours=30), timedelta(hours=25)))
    )
    assert 0 <= weighted_score <= 100, f"Score should be between 0 and 100, got {weighted_score}"
    print(f"✓ Old tweets: {base_score} → {weighted_score}")
def test_calculate_energy_score_all_sources():
    """Test energy score calculation with all sources."""
    print("\n=== Test: Calculate Energy Score (All Sources) ===")
    result = calculate_energy_score(
        match_id=1,
        team_id=1,
        twitter_sentiments=[
            {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'},
            {'compound': 0.7, 'positive': 0.7, 'negative': 0.1, 'neutral': 0.2, 'sentiment': 'positive'}
        ],
        reddit_sentiments=[
            {'compound': 0.3, 'positive': 0.4, 'negative': 0.3, 'neutral': 0.3, 'sentiment': 'positive'}
        ],
        rss_sentiments=[
            {'compound': 0.4, 'positive': 0.5, 'negative': 0.2, 'neutral': 0.3, 'sentiment': 'positive'}
        ]
    )
    # Result contract: all keys present, score normalized, all 3 sources
    # contributing, and high confidence.
    for key in ('score', 'confidence', 'sources_used'):
        assert key in result, f"Result should contain '{key}'"
    assert 0 <= result['score'] <= 100, f"Score should be between 0 and 100, got {result['score']}"
    assert len(result['sources_used']) == 3, f"Should use 3 sources, got {len(result['sources_used'])}"
    assert result['confidence'] > 0.6, f"Confidence should be > 0.6, got {result['confidence']}"
    print(f"✓ Score: {result['score']}")
    print(f"✓ Confidence: {result['confidence']}")
    print(f"✓ Sources used: {result['sources_used']}")
def test_calculate_energy_score_degraded_mode():
    """Test energy score calculation in degraded mode."""
    print("\n=== Test: Calculate Energy Score (Degraded Mode) ===")
    # Only Twitter delivers data; the other feeds are empty.
    result = calculate_energy_score(
        match_id=1,
        team_id=1,
        twitter_sentiments=[
            {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
        ],
        reddit_sentiments=[],
        rss_sentiments=[]
    )
    assert result['score'] >= 0, f"Score should be >= 0, got {result['score']}"
    assert result['confidence'] < 0.6, f"Confidence should be < 0.6 in degraded mode, got {result['confidence']}"
    assert len(result['sources_used']) == 1, f"Should use 1 source, got {len(result['sources_used'])}"
    assert 'twitter' in result['sources_used'], "Should use Twitter"
    print(f"✓ Score: {result['score']}")
    print(f"✓ Confidence: {result['confidence']} (lower in degraded mode)")
    print(f"✓ Sources used: {result['sources_used']}")
def test_calculate_energy_score_no_data():
    """Test energy score calculation with no data."""
    print("\n=== Test: Calculate Energy Score (No Data) ===")
    # With every feed empty the calculator must return a zeroed result.
    result = calculate_energy_score(
        match_id=1,
        team_id=1,
        twitter_sentiments=[],
        reddit_sentiments=[],
        rss_sentiments=[]
    )
    assert result['score'] == 0.0, f"Score should be 0.0, got {result['score']}"
    assert result['confidence'] == 0.0, f"Confidence should be 0.0, got {result['confidence']}"
    assert len(result['sources_used']) == 0, f"Should use 0 sources, got {len(result['sources_used'])}"
    print(f"✓ Score: {result['score']}")
    print(f"✓ Confidence: {result['confidence']}")
    print(f"✓ Sources used: {result['sources_used']}")
def main():
    """Run all manual tests; return 0 on success, 1 on any failure."""
    print("=" * 60)
    print("MANUAL TESTS FOR ENERGY CALCULATOR")
    print("=" * 60)
    # Run the suite in dependency order: primitives first, then the
    # end-to-end calculation scenarios.
    all_tests = (
        test_apply_source_weights,
        test_adjust_weights_for_degraded_mode,
        test_normalize_score,
        test_calculate_confidence,
        test_apply_temporal_weighting,
        test_calculate_energy_score_all_sources,
        test_calculate_energy_score_degraded_mode,
        test_calculate_energy_score_no_data,
    )
    try:
        for run_test in all_tests:
            run_test()
        print("\n" + "=" * 60)
        print("✅ ALL TESTS PASSED!")
        print("=" * 60)
        return 0
    except AssertionError as e:
        print(f"\n❌ TEST FAILED: {e}")
        return 1
    except Exception as e:
        print(f"\n❌ UNEXPECTED ERROR: {e}")
        import traceback
        traceback.print_exc()
        return 1
if __name__ == '__main__':
    import sys

    # Use sys.exit: the bare builtin exit() is a site.py convenience that
    # is not guaranteed to exist (e.g. when run with python -S).
    sys.exit(main())

View File

@@ -0,0 +1,394 @@
"""
Tests for Energy Service.
This module tests the energy service business logic including:
- Energy score calculation and storage
- Retrieval of energy scores
- Updating and deleting energy scores
- Filtering and querying
"""
import pytest
from datetime import datetime
from app.services.energy_service import (
calculate_and_store_energy_score,
get_energy_score,
get_energy_scores_by_match,
get_energy_scores_by_team,
get_energy_score_by_match_and_team,
update_energy_score,
delete_energy_score,
list_energy_scores
)
from app.schemas.energy_score import EnergyScoreCalculationRequest, EnergyScoreUpdate
class TestCalculateAndStoreEnergyScore:
    """Test energy score calculation and storage."""

    @staticmethod
    def _request(**kwargs):
        # Build a calculation request for match 1 / team 1 by default.
        params = {'match_id': 1, 'team_id': 1}
        params.update(kwargs)
        return EnergyScoreCalculationRequest(**params)

    def test_calculate_and_store_with_all_sources(self, db_session):
        """Test calculation and storage with all sources available."""
        request = self._request(
            twitter_sentiments=[
                {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
            ],
            reddit_sentiments=[
                {'compound': 0.3, 'positive': 0.4, 'negative': 0.3, 'neutral': 0.3, 'sentiment': 'positive'}
            ],
            rss_sentiments=[
                {'compound': 0.4, 'positive': 0.5, 'negative': 0.2, 'neutral': 0.3, 'sentiment': 'positive'}
            ]
        )
        energy_score = calculate_and_store_energy_score(db_session, request)
        # Persisted row carries the per-source breakdown and a normalized score.
        assert energy_score is not None
        assert energy_score.match_id == 1
        assert energy_score.team_id == 1
        assert 0 <= energy_score.score <= 100
        assert energy_score.confidence > 0
        assert len(energy_score.sources_used) == 3
        assert energy_score.twitter_score is not None
        assert energy_score.reddit_score is not None
        assert energy_score.rss_score is not None

    def test_calculate_and_store_with_twitter_only(self, db_session):
        """Test calculation and storage with only Twitter source."""
        request = self._request(
            twitter_sentiments=[
                {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
            ]
        )
        energy_score = calculate_and_store_energy_score(db_session, request)
        assert energy_score is not None
        assert energy_score.match_id == 1
        assert energy_score.team_id == 1
        assert len(energy_score.sources_used) == 1
        assert 'twitter' in energy_score.sources_used
        assert energy_score.twitter_score is not None
        assert energy_score.reddit_score is None
        assert energy_score.rss_score is None
        # Degraded (single-source) mode lowers confidence.
        assert energy_score.confidence < 0.7

    def test_calculate_and_store_no_sentiment_data(self, db_session):
        """Test calculation with no sentiment data."""
        request = self._request(
            twitter_sentiments=[],
            reddit_sentiments=[],
            rss_sentiments=[]
        )
        energy_score = calculate_and_store_energy_score(db_session, request)
        # Nothing to score: everything is zeroed out.
        assert energy_score is not None
        assert energy_score.score == 0.0
        assert energy_score.confidence == 0.0
        assert len(energy_score.sources_used) == 0
class TestGetEnergyScore:
    """Test retrieval of energy scores."""

    def test_get_energy_score_by_id(self, db_session):
        """Test retrieving energy score by ID."""
        request = EnergyScoreCalculationRequest(
            match_id=1,
            team_id=1,
            twitter_sentiments=[
                {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
            ]
        )
        stored = calculate_and_store_energy_score(db_session, request)
        fetched = get_energy_score(db_session, stored.id)
        # The fetched row mirrors the stored one field by field.
        assert fetched is not None
        assert fetched.id == stored.id
        assert fetched.match_id == stored.match_id
        assert fetched.team_id == stored.team_id
        assert fetched.score == stored.score

    def test_get_energy_score_not_found(self, db_session):
        """Test retrieving non-existent energy score."""
        # An unknown primary key yields None rather than raising.
        assert get_energy_score(db_session, 99999) is None
class TestGetEnergyScoresByMatch:
    """Test retrieval of energy scores by match."""

    def test_get_energy_scores_by_match_id(self, db_session):
        """Test retrieving all energy scores for a specific match."""
        # Seed two scores for match 1 and one for match 2.
        for match_id in (1, 1, 2):
            calculate_and_store_energy_score(
                db_session,
                EnergyScoreCalculationRequest(
                    match_id=match_id,
                    team_id=1,
                    twitter_sentiments=[
                        {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
                    ]
                )
            )
        scores_for_match_1 = get_energy_scores_by_match(db_session, 1)
        scores_for_match_2 = get_energy_scores_by_match(db_session, 2)
        assert len(scores_for_match_1) == 2
        assert all(score.match_id == 1 for score in scores_for_match_1)
        assert len(scores_for_match_2) == 1
        assert scores_for_match_2[0].match_id == 2
class TestGetEnergyScoresByTeam:
    """Test retrieval of energy scores by team."""

    def test_get_energy_scores_by_team_id(self, db_session):
        """Test retrieving all energy scores for a specific team."""
        # Seed two scores for team 1 and one for team 2, same match.
        for team_id in (1, 1, 2):
            calculate_and_store_energy_score(
                db_session,
                EnergyScoreCalculationRequest(
                    match_id=1,
                    team_id=team_id,
                    twitter_sentiments=[
                        {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
                    ]
                )
            )
        scores_for_team_1 = get_energy_scores_by_team(db_session, 1)
        scores_for_team_2 = get_energy_scores_by_team(db_session, 2)
        assert len(scores_for_team_1) == 2
        assert all(score.team_id == 1 for score in scores_for_team_1)
        assert len(scores_for_team_2) == 1
        assert scores_for_team_2[0].team_id == 2
class TestGetEnergyScoreByMatchAndTeam:
    """Test retrieval of most recent energy score by match and team."""

    def test_get_energy_score_by_match_and_team(self, db_session):
        """Test retrieving most recent energy score for match and team."""
        calculate_and_store_energy_score(
            db_session,
            EnergyScoreCalculationRequest(
                match_id=1,
                team_id=1,
                twitter_sentiments=[
                    {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
                ]
            )
        )
        found = get_energy_score_by_match_and_team(db_session, 1, 1)
        assert found is not None
        assert found.match_id == 1
        assert found.team_id == 1

    def test_get_energy_score_by_match_and_team_not_found(self, db_session):
        """Test retrieving non-existent energy score."""
        # Unknown (match, team) pair yields None rather than raising.
        assert get_energy_score_by_match_and_team(db_session, 999, 999) is None
class TestUpdateEnergyScore:
    """Test updating energy scores."""

    def test_update_energy_score(self, db_session):
        """Test updating an existing energy score."""
        original = calculate_and_store_energy_score(
            db_session,
            EnergyScoreCalculationRequest(
                match_id=1,
                team_id=1,
                twitter_sentiments=[
                    {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
                ]
            )
        )
        updated = update_energy_score(
            db_session,
            original.id,
            EnergyScoreUpdate(score=75.0, confidence=0.9)
        )
        assert updated is not None
        assert updated.id == original.id
        assert updated.score == 75.0
        assert updated.confidence == 0.9
        # NOTE(review): strictly-greater comparison assumes the updated_at
        # column has finer resolution than the test's runtime — could be
        # flaky on coarse clocks; confirm the column's precision.
        assert updated.updated_at > original.updated_at

    def test_update_energy_score_not_found(self, db_session):
        """Test updating non-existent energy score."""
        update = EnergyScoreUpdate(score=75.0, confidence=0.9)
        assert update_energy_score(db_session, 99999, update) is None
class TestDeleteEnergyScore:
    """Test deleting energy scores."""

    def test_delete_energy_score(self, db_session):
        """Test deleting an existing energy score."""
        stored = calculate_and_store_energy_score(
            db_session,
            EnergyScoreCalculationRequest(
                match_id=1,
                team_id=1,
                twitter_sentiments=[
                    {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
                ]
            )
        )
        assert delete_energy_score(db_session, stored.id) is True
        # The row must be gone afterwards.
        assert get_energy_score(db_session, stored.id) is None

    def test_delete_energy_score_not_found(self, db_session):
        """Test deleting non-existent energy score."""
        assert delete_energy_score(db_session, 99999) is False
class TestListEnergyScores:
    """Test listing and filtering energy scores."""

    @staticmethod
    def _store(db_session, match_id=1):
        # Persist one score with a minimal single-tweet payload.
        return calculate_and_store_energy_score(
            db_session,
            EnergyScoreCalculationRequest(
                match_id=match_id,
                team_id=1,
                twitter_sentiments=[
                    {'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
                ]
            )
        )

    def test_list_energy_scores_default(self, db_session):
        """Test listing energy scores with default parameters."""
        for i in range(5):
            self._store(db_session, match_id=i)
        assert len(list_energy_scores(db_session)) == 5

    def test_list_energy_scores_with_match_filter(self, db_session):
        """Test listing energy scores filtered by match ID."""
        for match_id in (1, 1, 2):
            self._store(db_session, match_id=match_id)
        scores = list_energy_scores(db_session, match_id=1)
        assert len(scores) == 2
        assert all(score.match_id == 1 for score in scores)

    def test_list_energy_scores_with_score_filter(self, db_session):
        """Test listing energy scores filtered by score range."""
        # Seed four rows, then force their scores to known values so the
        # [40, 80] window selects exactly two of them.
        for forced_score in (30.0, 50.0, 70.0, 90.0):
            row = self._store(db_session)
            row.score = forced_score
            db_session.commit()
        scores = list_energy_scores(db_session, min_score=40.0, max_score=80.0)
        assert len(scores) == 2
        assert all(40.0 <= score.score <= 80.0 for score in scores)

    def test_list_energy_scores_with_pagination(self, db_session):
        """Test listing energy scores with pagination."""
        for i in range(10):
            self._store(db_session, match_id=i)
        page1 = list_energy_scores(db_session, limit=5, offset=0)
        page2 = list_energy_scores(db_session, limit=5, offset=5)
        assert len(page1) == 5
        assert len(page2) == 5
        # The two pages must not share any rows.
        page1_ids = {score.id for score in page1}
        page2_ids = {score.id for score in page2}
        assert page1_ids.isdisjoint(page2_ids)
if __name__ == '__main__':
    # Allow running this test module directly (outside the pytest CLI),
    # with verbose per-test output.
    pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,281 @@
"""
Tests for energy calculation worker.
"""
import pytest
from unittest.mock import Mock, patch
from sqlalchemy.orm import Session
from app.workers.energy_worker import (
EnergyWorker,
create_energy_worker
)
class TestEnergyWorker:
    """Tests for EnergyWorker class.

    NOTE: stacked @patch decorators inject their mocks bottom-up — the
    decorator closest to the function supplies the first mock parameter
    after self.
    """

    def test_initialization(self):
        """Test energy worker initialization."""
        worker = EnergyWorker()
        # No specific initialization required for energy worker
        assert worker is not None

    @patch('app.workers.energy_worker.calculate_and_store_energy_score')
    @patch('app.workers.energy_worker.get_energy_score_by_match_and_team')
    def test_execute_energy_calculation_task_new(
        self,
        mock_get_score,
        mock_calculate_store
    ):
        """Test executing energy calculation for new score."""
        # Create worker
        worker = EnergyWorker()
        # Mock database session
        mock_db = Mock(spec=Session)
        # Mock no existing score — forces the "create new" code path.
        mock_get_score.return_value = None
        # Mock calculation: attributes mirror the EnergyScore model fields
        # that the worker copies into its result payload.
        mock_energy_score = Mock()
        mock_energy_score.score = 75.5
        mock_energy_score.confidence = 0.82
        mock_energy_score.sources_used = ['twitter', 'reddit']
        mock_energy_score.twitter_score = 0.6
        mock_energy_score.reddit_score = 0.7
        mock_energy_score.rss_score = 0.5
        mock_energy_score.temporal_factor = 1.0
        mock_calculate_store.return_value = mock_energy_score
        # Execute task
        task = {
            'match_id': 123,
            'team_id': 456,
            'twitter_sentiments': [
                {'compound': 0.6, 'sentiment': 'positive'},
                {'compound': 0.7, 'sentiment': 'positive'}
            ],
            'reddit_sentiments': [
                {'compound': 0.7, 'sentiment': 'positive'},
                {'compound': 0.8, 'sentiment': 'positive'}
            ],
            'rss_sentiments': [
                {'compound': 0.5, 'sentiment': 'positive'}
            ],
            'tweets_with_timestamps': []
        }
        result = worker.execute_energy_calculation_task(task, mock_db)
        # Verify calculation called
        mock_calculate_store.assert_called_once()
        # Verify result: the payload echoes the mocked score plus metadata.
        assert result['energy_score'] == 75.5
        assert result['confidence'] == 0.82
        assert result['sources_used'] == ['twitter', 'reddit']
        assert result['status'] == 'success'
        assert result['metadata']['match_id'] == 123
        assert result['metadata']['team_id'] == 456
        assert result['metadata']['twitter_score'] == 0.6
        assert result['metadata']['reddit_score'] == 0.7
        assert result['metadata']['rss_score'] == 0.5
        assert result['metadata']['temporal_factor'] == 1.0

    @patch('app.workers.energy_worker.get_energy_score_by_match_and_team')
    def test_execute_energy_calculation_task_existing(
        self,
        mock_get_score
    ):
        """Test executing energy calculation for existing score."""
        # Create worker
        worker = EnergyWorker()
        # Mock database session
        mock_db = Mock(spec=Session)
        # Mock existing score — the worker should reuse it instead of
        # recalculating (calculate_and_store is deliberately NOT patched
        # here, so an unexpected call would hit the real import).
        mock_existing_score = Mock()
        mock_existing_score.score = 72.0
        mock_existing_score.confidence = 0.78
        mock_existing_score.sources_used = ['twitter', 'reddit']
        mock_get_score.return_value = mock_existing_score
        # Execute task
        task = {
            'match_id': 123,
            'team_id': 456,
            'twitter_sentiments': [],
            'reddit_sentiments': [],
            'rss_sentiments': [],
            'tweets_with_timestamps': []
        }
        result = worker.execute_energy_calculation_task(task, mock_db)
        # Verify result from existing score
        assert result['energy_score'] == 72.0
        assert result['confidence'] == 0.78
        assert result['sources_used'] == ['twitter', 'reddit']
        assert result['status'] == 'success'
        assert result['metadata']['updated_existing'] is True

    @patch('app.workers.energy_worker.calculate_and_store_energy_score')
    @patch('app.workers.energy_worker.get_energy_score_by_match_and_team')
    def test_execute_energy_calculation_task_error_handling(
        self,
        mock_get_score,
        mock_calculate_store
    ):
        """Test error handling in energy calculation."""
        worker = EnergyWorker()
        mock_db = Mock(spec=Session)
        # Mock no existing score
        mock_get_score.return_value = None
        # Mock calculation error — the worker must catch it, not raise.
        mock_calculate_store.side_effect = Exception("Calculation error")
        # Execute task
        task = {
            'match_id': 123,
            'team_id': 456,
            'twitter_sentiments': [],
            'reddit_sentiments': [],
            'rss_sentiments': [],
            'tweets_with_timestamps': []
        }
        result = worker.execute_energy_calculation_task(task, mock_db)
        # Verify error handling: a zeroed payload with status 'error'.
        assert result['energy_score'] == 0.0
        assert result['confidence'] == 0.0
        assert result['sources_used'] == []
        assert result['status'] == 'error'
        assert 'error' in result

    @patch('app.workers.energy_worker.calculate_energy_score')
    def test_calculate_mock_energy(self, mock_calculate):
        """Test mock energy calculation (for testing)."""
        worker = EnergyWorker()
        # Mock calculation
        mock_calculate.return_value = {
            'score': 75.5,
            'confidence': 0.82,
            'sources_used': ['twitter', 'reddit']
        }
        # Calculate mock energy
        twitter_sentiments = [
            {'compound': 0.6, 'sentiment': 'positive'}
        ]
        reddit_sentiments = [
            {'compound': 0.7, 'sentiment': 'positive'}
        ]
        result = worker.calculate_mock_energy(
            twitter_sentiments=twitter_sentiments,
            reddit_sentiments=reddit_sentiments
        )
        # Verify calculation called with placeholder ids (0/0) and the
        # omitted sources defaulted to empty lists.
        mock_calculate.assert_called_once_with(
            match_id=0,
            team_id=0,
            twitter_sentiments=twitter_sentiments,
            reddit_sentiments=reddit_sentiments,
            rss_sentiments=[],
            tweets_with_timestamps=[]
        )
        # Verify result
        assert result['energy_score'] == 75.5
        assert result['confidence'] == 0.82
        assert result['sources_used'] == ['twitter', 'reddit']
        assert result['status'] == 'success'

    @patch('app.workers.energy_worker.calculate_energy_score')
    def test_calculate_mock_energy_error(self, mock_calculate):
        """Test error handling in mock energy calculation."""
        worker = EnergyWorker()
        # Mock calculation error
        mock_calculate.side_effect = Exception("Calculation error")
        # Calculate mock energy
        result = worker.calculate_mock_energy(
            twitter_sentiments=[],
            reddit_sentiments=[]
        )
        # Verify error handling
        assert result['energy_score'] == 0.0
        assert result['confidence'] == 0.0
        assert result['sources_used'] == []
        assert result['status'] == 'error'
        assert 'error' in result

    @patch('app.workers.energy_worker.calculate_energy_score')
    def test_calculate_mock_energy_with_all_sources(self, mock_calculate):
        """Test mock energy calculation with all sources."""
        worker = EnergyWorker()
        # Mock calculation
        mock_calculate.return_value = {
            'score': 80.0,
            'confidence': 0.85,
            'sources_used': ['twitter', 'reddit', 'rss']
        }
        # Calculate mock energy with all sources
        twitter_sentiments = [
            {'compound': 0.6, 'sentiment': 'positive'}
        ]
        reddit_sentiments = [
            {'compound': 0.7, 'sentiment': 'positive'}
        ]
        rss_sentiments = [
            {'compound': 0.5, 'sentiment': 'positive'}
        ]
        # NOTE(review): other fixtures in this suite use 'created_at' for
        # tweet times; the 'timestamp' key here is only passed through to a
        # mocked call, but confirm which key the real calculator expects.
        tweets_with_timestamps = [
            {'timestamp': '2026-01-17T10:00:00', 'text': 'Great match!'}
        ]
        result = worker.calculate_mock_energy(
            twitter_sentiments=twitter_sentiments,
            reddit_sentiments=reddit_sentiments,
            rss_sentiments=rss_sentiments,
            tweets_with_timestamps=tweets_with_timestamps
        )
        # Verify calculation called with all sources
        mock_calculate.assert_called_once_with(
            match_id=0,
            team_id=0,
            twitter_sentiments=twitter_sentiments,
            reddit_sentiments=reddit_sentiments,
            rss_sentiments=rss_sentiments,
            tweets_with_timestamps=tweets_with_timestamps
        )
        # Verify result
        assert result['energy_score'] == 80.0
        assert result['confidence'] == 0.85
        assert result['sources_used'] == ['twitter', 'reddit', 'rss']
        assert result['status'] == 'success'
class TestCreateEnergyWorker:
    """Tests for create_energy_worker factory function."""

    def test_create_energy_worker(self):
        """The factory returns a ready EnergyWorker instance."""
        assert isinstance(create_energy_worker(), EnergyWorker)

View File

@@ -0,0 +1,152 @@
"""
Unit tests for Match model.
"""
import pytest
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from app.models.match import Match
from app.models.prediction import Prediction
from app.database import Base, engine, SessionLocal
@pytest.fixture(scope="function")
def db_session():
    """Create a fresh, fully isolated database session for each test.

    FIX: the previous version bound to the application's configured
    ``engine``/``SessionLocal`` from ``app.database`` and ran
    ``drop_all`` on it — i.e. it created and dropped tables on whatever
    real database the app points at. Use a private in-memory SQLite
    engine instead, consistent with the shared fixture in conftest.py.
    """
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # Private engine per test: the ':memory:' database disappears with it.
    test_engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=test_engine)
    TestingSession = sessionmaker(autocommit=False, autoflush=False, bind=test_engine)
    session = TestingSession()
    try:
        yield session
        session.rollback()
    finally:
        session.close()
        Base.metadata.drop_all(bind=test_engine)
        test_engine.dispose()
class TestMatchModel:
    """Unit tests for the Match SQLAlchemy model."""

    def test_match_creation(self, db_session: Session):
        """Persisting a fully-populated match assigns an id and keeps all fields."""
        match = Match(
            home_team="PSG",
            away_team="Olympique de Marseille",
            date=datetime.now(timezone.utc),
            league="Ligue 1",
            status="scheduled"
        )
        db_session.add(match)
        db_session.commit()
        db_session.refresh(match)
        assert match.id is not None
        assert match.home_team == "PSG"
        assert match.away_team == "Olympique de Marseille"
        assert match.league == "Ligue 1"
        assert match.status == "scheduled"

    def test_match_required_fields(self, db_session: Session):
        """Committing a match without home_team raises IntegrityError.

        Narrowed from bare ``Exception`` (the original inline comment
        already said "IntegrityError expected") so the test cannot pass
        on an unrelated failure raised before the flush reaches the DB.
        """
        # Local import keeps the narrowed exception close to its use and
        # avoids touching the module's import block.
        from sqlalchemy.exc import IntegrityError

        # Missing home_team
        match = Match(
            away_team="Olympique de Marseille",
            date=datetime.now(timezone.utc),
            league="Ligue 1",
            status="scheduled"
        )
        db_session.add(match)
        with pytest.raises(IntegrityError):
            db_session.commit()

    def test_match_to_dict(self, db_session: Session):
        """to_dict exposes every persisted field plus id and date."""
        match = Match(
            home_team="Barcelona",
            away_team="Real Madrid",
            date=datetime.now(timezone.utc),
            league="La Liga",
            status="in_progress"
        )
        db_session.add(match)
        db_session.commit()
        db_session.refresh(match)
        match_dict = match.to_dict()
        assert match_dict['home_team'] == "Barcelona"
        assert match_dict['away_team'] == "Real Madrid"
        assert match_dict['league'] == "La Liga"
        assert match_dict['status'] == "in_progress"
        assert 'id' in match_dict
        assert 'date' in match_dict

    def test_match_repr(self, db_session: Session):
        """__repr__ names the class and includes the id and both team names."""
        match = Match(
            home_team="Manchester City",
            away_team="Liverpool",
            date=datetime.now(timezone.utc),
            league="Premier League",
            status="scheduled"
        )
        db_session.add(match)
        db_session.commit()
        db_session.refresh(match)
        repr_str = repr(match)
        assert "Match" in repr_str
        assert "id=" in repr_str
        assert "Manchester City" in repr_str
        assert "Liverpool" in repr_str

    def test_match_relationships(self, db_session: Session):
        """match.predictions loads every prediction row pointing at the match."""
        match = Match(
            home_team="Juventus",
            away_team="Inter Milan",
            date=datetime.now(timezone.utc),
            league="Serie A",
            status="scheduled"
        )
        db_session.add(match)
        db_session.commit()
        db_session.refresh(match)
        # Create two predictions referencing the match.
        prediction1 = Prediction(
            match_id=match.id,
            energy_score="high",
            confidence="80%",
            predicted_winner="Juventus",
            created_at=datetime.now(timezone.utc)
        )
        prediction2 = Prediction(
            match_id=match.id,
            energy_score="medium",
            confidence="60%",
            predicted_winner="Inter Milan",
            created_at=datetime.now(timezone.utc)
        )
        db_session.add(prediction1)
        db_session.add(prediction2)
        db_session.commit()
        # Refresh match to load the relationship collection.
        db_session.refresh(match)
        assert len(match.predictions) == 2
        # Avoid asserting a specific collection order — only membership matters.
        assert all(p.match_id == match.id for p in match.predictions)

View File

@@ -0,0 +1,215 @@
"""
Unit tests for Match Pydantic schemas.
"""
import pytest
from datetime import datetime, timezone
from pydantic import ValidationError
from app.schemas.match import (
MatchBase,
MatchCreate,
MatchUpdate,
MatchResponse,
MatchListResponse,
MatchStatsResponse
)
class TestMatchBase:
    """Validation tests for the MatchBase schema."""

    def test_match_base_valid(self):
        """A fully-populated payload should validate cleanly."""
        payload = dict(
            home_team="PSG",
            away_team="Olympique de Marseille",
            date=datetime.now(timezone.utc),
            league="Ligue 1",
            status="scheduled",
        )
        schema = MatchBase(**payload)
        assert schema.home_team == "PSG"
        assert schema.away_team == "Olympique de Marseille"
        assert schema.league == "Ligue 1"
        assert schema.status == "scheduled"

    def test_match_base_home_team_too_long(self):
        """A home_team longer than 255 characters is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            MatchBase(
                home_team="A" * 256,  # one char past the limit
                away_team="Olympique de Marseille",
                date=datetime.now(timezone.utc),
                league="Ligue 1",
                status="scheduled",
            )
        message = str(exc_info.value).lower()
        assert "at most 255 characters" in message

    def test_match_base_away_team_empty(self):
        """An empty away_team is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            MatchBase(
                home_team="PSG",
                away_team="",  # violates the 1-character minimum
                date=datetime.now(timezone.utc),
                league="Ligue 1",
                status="scheduled",
            )
        message = str(exc_info.value).lower()
        assert "at least 1 character" in message

    def test_match_base_status_too_long(self):
        """A status longer than 50 characters is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            MatchBase(
                home_team="PSG",
                away_team="Olympique de Marseille",
                date=datetime.now(timezone.utc),
                league="Ligue 1",
                status="A" * 51,  # one char past the limit
            )
        message = str(exc_info.value).lower()
        assert "at most 50 characters" in message
class TestMatchCreate:
    """Validation tests for the MatchCreate schema."""

    def test_match_create_valid(self):
        """A complete payload should produce a populated MatchCreate."""
        created = MatchCreate(
            home_team="Barcelona",
            away_team="Real Madrid",
            date=datetime.now(timezone.utc),
            league="La Liga",
            status="in_progress",
        )
        assert created.home_team == "Barcelona"
        assert created.away_team == "Real Madrid"
        assert created.status == "in_progress"
class TestMatchUpdate:
    """Validation tests for the MatchUpdate schema."""

    def test_match_update_partial(self):
        """Only the provided field is set; the rest default to None."""
        partial = MatchUpdate(status="completed")
        assert partial.status == "completed"
        assert partial.home_team is None
        assert partial.away_team is None

    def test_match_update_all_fields(self):
        """Every field can be supplied at once."""
        full = MatchUpdate(
            home_team="Manchester City",
            away_team="Liverpool",
            date=datetime.now(timezone.utc),
            league="Premier League",
            status="in_progress",
        )
        assert full.home_team == "Manchester City"
        assert full.away_team == "Liverpool"
        assert full.league == "Premier League"
        assert full.status == "in_progress"

    def test_match_update_empty(self):
        """An argument-free MatchUpdate leaves every field as None."""
        empty = MatchUpdate()
        for attr in ("home_team", "away_team", "date", "league", "status"):
            assert getattr(empty, attr) is None
class TestMatchResponse:
    """Tests for the MatchResponse schema."""

    def test_match_response_from_dict(self):
        """A plain dict carrying an id should hydrate a MatchResponse."""
        hydrated = MatchResponse(
            id=1,
            home_team="Juventus",
            away_team="Inter Milan",
            date=datetime.now(timezone.utc),
            league="Serie A",
            status="scheduled",
        )
        assert hydrated.id == 1
        assert hydrated.home_team == "Juventus"
        assert hydrated.away_team == "Inter Milan"
        assert hydrated.league == "Serie A"
        assert hydrated.status == "scheduled"
class TestMatchListResponse:
    """Tests for the MatchListResponse schema."""

    def test_match_list_response(self):
        """A list payload with count and meta should round-trip."""
        kickoff = datetime.now(timezone.utc)
        first = {
            "id": 1,
            "home_team": "PSG",
            "away_team": "Olympique de Marseille",
            "date": kickoff,
            "league": "Ligue 1",
            "status": "scheduled",
        }
        second = {
            "id": 2,
            "home_team": "Barcelona",
            "away_team": "Real Madrid",
            "date": kickoff,
            "league": "La Liga",
            "status": "scheduled",
        }
        listing = MatchListResponse(data=[first, second], count=2, meta={"page": 1})
        assert listing.count == 2
        assert len(listing.data) == 2
        assert listing.meta["page"] == 1
class TestMatchStatsResponse:
    """Tests for the MatchStatsResponse schema."""

    def test_match_stats_response(self):
        """Aggregate statistics should hydrate field-for-field."""
        hydrated = MatchStatsResponse(
            total_matches=10,
            matches_by_league={"Ligue 1": 5, "La Liga": 5},
            matches_by_status={"scheduled": 5, "completed": 5},
            upcoming_matches=5,
            completed_matches=5,
        )
        assert hydrated.total_matches == 10
        assert hydrated.matches_by_league["Ligue 1"] == 5
        assert hydrated.matches_by_status["scheduled"] == 5
        assert hydrated.upcoming_matches == 5
        assert hydrated.completed_matches == 5

View File

@@ -0,0 +1,538 @@
"""
Unit tests for Prediction API endpoints.
"""
import pytest
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from app.database import Base, engine, SessionLocal
from app.main import app
from app.models.match import Match
from app.models.prediction import Prediction
@pytest.fixture(scope="function")
def db_session():
    """Create a fresh database session for each test.

    Builds the full schema before the test, yields an open session,
    rolls back any uncommitted work, then closes the session and drops
    all tables so each test runs against an empty database.

    Yields:
        Session: an open SQLAlchemy session bound to the app engine.
    """
    Base.metadata.create_all(bind=engine)
    session = SessionLocal()
    try:
        yield session
        # Discard uncommitted work before teardown.
        session.rollback()
    finally:
        session.close()
        # NOTE(review): shares the application's engine — confirm the
        # configured DSN is a test database before trusting drop_all.
        Base.metadata.drop_all(bind=engine)
@pytest.fixture
def client():
    """Provide a FastAPI TestClient bound to the application under test."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def sample_match(db_session: Session) -> Match:
    """Persist and return a single scheduled Ligue 1 match."""
    fixture_match = Match(
        home_team="PSG",
        away_team="Olympique de Marseille",
        date=datetime.now(timezone.utc),
        league="Ligue 1",
        status="scheduled",
    )
    db_session.add(fixture_match)
    db_session.commit()
    db_session.refresh(fixture_match)
    return fixture_match
@pytest.fixture
def sample_prediction(db_session: Session, sample_match: Match) -> Prediction:
    """Persist and return one prediction attached to sample_match."""
    fixture_prediction = Prediction(
        match_id=sample_match.id,
        energy_score="high",
        confidence="75.0%",
        predicted_winner="PSG",
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(fixture_prediction)
    db_session.commit()
    db_session.refresh(fixture_prediction)
    return fixture_prediction
class TestCreatePredictionEndpoint:
    """Test POST /api/v1/predictions/matches/{match_id}/predict endpoint."""

    def test_create_prediction_success(self, client: TestClient, sample_match: Match):
        """A valid energy pair yields a 201 and a fully-populated prediction."""
        resp = client.post(
            f"/api/v1/predictions/matches/{sample_match.id}/predict",
            params={
                "home_energy": 65.0,
                "away_energy": 45.0,
                "energy_score_label": "high",
            },
        )
        assert resp.status_code == 201
        body = resp.json()
        assert body["match_id"] == sample_match.id
        assert body["energy_score"] == "high"
        assert "%" in body["confidence"]
        assert body["predicted_winner"] == "PSG"
        assert "id" in body
        assert "created_at" in body

    def test_create_prediction_default_energy_label(self, client: TestClient, sample_match: Match):
        """Omitting the label lets the API derive one automatically."""
        resp = client.post(
            f"/api/v1/predictions/matches/{sample_match.id}/predict",
            params={"home_energy": 75.0, "away_energy": 65.0},
        )
        assert resp.status_code == 201
        assert resp.json()["energy_score"] in ["high", "very_high", "medium", "low"]

    def test_create_prediction_home_win(self, client: TestClient, sample_match: Match):
        """Higher home energy predicts the home side."""
        resp = client.post(
            f"/api/v1/predictions/matches/{sample_match.id}/predict",
            params={"home_energy": 70.0, "away_energy": 30.0},
        )
        assert resp.status_code == 201
        assert resp.json()["predicted_winner"] == "PSG"

    def test_create_prediction_away_win(self, client: TestClient, sample_match: Match):
        """Higher away energy predicts the away side."""
        resp = client.post(
            f"/api/v1/predictions/matches/{sample_match.id}/predict",
            params={"home_energy": 30.0, "away_energy": 70.0},
        )
        assert resp.status_code == 201
        assert resp.json()["predicted_winner"] == "Olympique de Marseille"

    def test_create_prediction_draw(self, client: TestClient, sample_match: Match):
        """Equal energies predict a draw with zero confidence."""
        resp = client.post(
            f"/api/v1/predictions/matches/{sample_match.id}/predict",
            params={"home_energy": 50.0, "away_energy": 50.0},
        )
        assert resp.status_code == 201
        body = resp.json()
        assert body["predicted_winner"] == "Draw"
        assert body["confidence"] == "0.0%"

    def test_create_prediction_nonexistent_match(self, client: TestClient):
        """Predicting for an unknown match id returns 400 with an explanation."""
        resp = client.post(
            "/api/v1/predictions/matches/999/predict",
            params={"home_energy": 65.0, "away_energy": 45.0},
        )
        assert resp.status_code == 400
        assert "not found" in resp.json()["detail"].lower()

    def test_create_prediction_negative_energy(self, client: TestClient, sample_match: Match):
        """Negative energy values are rejected with a 400."""
        resp = client.post(
            f"/api/v1/predictions/matches/{sample_match.id}/predict",
            params={"home_energy": -10.0, "away_energy": 45.0},
        )
        assert resp.status_code == 400
class TestGetPredictionEndpoint:
    """Test GET /api/v1/predictions/{prediction_id} endpoint."""

    def test_get_prediction_success(self, client: TestClient, sample_prediction: Prediction):
        """Fetching an existing prediction returns all persisted fields."""
        resp = client.get(f"/api/v1/predictions/{sample_prediction.id}")
        assert resp.status_code == 200
        body = resp.json()
        assert body["id"] == sample_prediction.id
        assert body["match_id"] == sample_prediction.match_id
        assert body["energy_score"] == sample_prediction.energy_score
        assert body["confidence"] == sample_prediction.confidence
        assert body["predicted_winner"] == sample_prediction.predicted_winner

    def test_get_prediction_not_found(self, client: TestClient):
        """Fetching an unknown id yields a 404."""
        resp = client.get("/api/v1/predictions/999")
        assert resp.status_code == 404
        assert "not found" in resp.json()["detail"].lower()
class TestGetPredictionsForMatchEndpoint:
    """Test GET /api/v1/predictions/matches/{match_id} endpoint."""

    def test_get_predictions_for_match_success(self, client: TestClient, db_session: Session, sample_match: Match):
        """All predictions stored for a match are listed with a count."""
        from app.services.prediction_service import PredictionService
        service = PredictionService(db_session)
        for step in range(3):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + step * 10,
                away_energy=40.0 + step * 5,
            )
        resp = client.get(f"/api/v1/predictions/matches/{sample_match.id}")
        assert resp.status_code == 200
        body = resp.json()
        assert "data" in body
        assert "count" in body
        assert body["count"] == 3
        assert len(body["data"]) == 3
        assert all(entry["match_id"] == sample_match.id for entry in body["data"])

    def test_get_predictions_for_empty_match(self, client: TestClient, sample_match: Match):
        """A match without predictions yields an empty list and zero count."""
        resp = client.get(f"/api/v1/predictions/matches/{sample_match.id}")
        assert resp.status_code == 200
        body = resp.json()
        assert body["count"] == 0
        assert len(body["data"]) == 0
class TestGetLatestPredictionForMatchEndpoint:
    """Test GET /api/v1/predictions/matches/{match_id}/latest endpoint."""

    def test_get_latest_prediction_success(self, client: TestClient, db_session: Session, sample_match: Match):
        """The endpoint returns the most recently created prediction."""
        import time
        from app.services.prediction_service import PredictionService
        service = PredictionService(db_session)
        service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=50.0,
            away_energy=40.0,
        )
        time.sleep(0.01)  # ensure a strictly later created_at timestamp
        newest = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=60.0,
            away_energy=35.0,
        )
        resp = client.get(f"/api/v1/predictions/matches/{sample_match.id}/latest")
        assert resp.status_code == 200
        body = resp.json()
        assert body["id"] == newest.id
        assert body["match_id"] == sample_match.id

    def test_get_latest_prediction_not_found(self, client: TestClient, sample_match: Match):
        """A match with no predictions yields a 404."""
        resp = client.get(f"/api/v1/predictions/matches/{sample_match.id}/latest")
        assert resp.status_code == 404
        assert "No predictions found" in resp.json()["detail"]
class TestDeletePredictionEndpoint:
    """Test DELETE /api/v1/predictions/{prediction_id} endpoint."""

    def test_delete_prediction_success(self, client: TestClient, sample_prediction: Prediction):
        """Deletion returns 204 and the prediction is gone afterwards."""
        resp = client.delete(f"/api/v1/predictions/{sample_prediction.id}")
        assert resp.status_code == 204
        # A follow-up read must now miss.
        followup = client.get(f"/api/v1/predictions/{sample_prediction.id}")
        assert followup.status_code == 404

    def test_delete_prediction_not_found(self, client: TestClient):
        """Deleting an unknown id yields a 404."""
        resp = client.delete("/api/v1/predictions/999")
        assert resp.status_code == 404
class TestGetPredictionsEndpoint:
    """Test GET /api/v1/predictions endpoint with pagination and filters."""

    def test_get_predictions_default(self, client: TestClient, db_session: Session, sample_match: Match):
        """Default paging metadata is present with limit=20 and offset=0."""
        from app.services.prediction_service import PredictionService
        service = PredictionService(db_session)
        for step in range(5):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + step * 5,
                away_energy=40.0 + step * 3,
            )
        resp = client.get("/api/v1/predictions")
        assert resp.status_code == 200
        body = resp.json()
        assert "data" in body
        assert "meta" in body
        meta = body["meta"]
        for key in ("total", "limit", "offset", "timestamp"):
            assert key in meta
        assert meta["version"] == "v1"
        assert meta["limit"] == 20
        assert meta["offset"] == 0

    def test_get_predictions_with_pagination(self, client: TestClient, db_session: Session, sample_match: Match):
        """limit/offset slice the result set and report the full total."""
        from app.services.prediction_service import PredictionService
        service = PredictionService(db_session)
        for step in range(25):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + step,
                away_energy=40.0 + step,
            )
        # First page of ten.
        first_page = client.get("/api/v1/predictions?limit=10&offset=0")
        assert first_page.status_code == 200
        first_body = first_page.json()
        assert len(first_body["data"]) == 10
        assert first_body["meta"]["limit"] == 10
        assert first_body["meta"]["offset"] == 0
        assert first_body["meta"]["total"] == 25
        # Second page of ten.
        second_page = client.get("/api/v1/predictions?limit=10&offset=10")
        assert second_page.status_code == 200
        second_body = second_page.json()
        assert len(second_body["data"]) == 10
        assert second_body["meta"]["offset"] == 10

    def test_get_predictions_with_league_filter(self, client: TestClient, db_session: Session):
        """The league query parameter restricts results to that league."""
        from app.services.prediction_service import PredictionService
        service = PredictionService(db_session)
        ligue1_match = Match(
            home_team="PSG",
            away_team="Olympique de Marseille",
            date=datetime.now(timezone.utc),
            league="Ligue 1",
            status="scheduled",
        )
        liga_match = Match(
            home_team="Real Madrid",
            away_team="Barcelona",
            date=datetime.now(timezone.utc),
            league="La Liga",
            status="scheduled",
        )
        db_session.add_all([ligue1_match, liga_match])
        db_session.commit()
        db_session.refresh(ligue1_match)
        db_session.refresh(liga_match)
        service.create_prediction_for_match(ligue1_match.id, 65.0, 45.0)
        service.create_prediction_for_match(liga_match.id, 60.0, 50.0)
        resp = client.get("/api/v1/predictions?league=Ligue 1")
        assert resp.status_code == 200
        body = resp.json()
        assert len(body["data"]) == 1
        assert body["data"][0]["match"]["league"] == "Ligue 1"

    def test_get_predictions_with_date_filter(self, client: TestClient, db_session: Session, sample_match: Match):
        """date_min excludes predictions for matches before that date."""
        from app.services.prediction_service import PredictionService
        service = PredictionService(db_session)
        upcoming = datetime.now(timezone.utc).replace(day=20, month=2)
        bygone = datetime.now(timezone.utc).replace(day=1, month=1, year=2025)
        upcoming_match = Match(
            home_team="PSG",
            away_team="Olympique de Marseille",
            date=upcoming,
            league="Ligue 1",
            status="scheduled",
        )
        bygone_match = Match(
            home_team="Real Madrid",
            away_team="Barcelona",
            date=bygone,
            league="La Liga",
            status="scheduled",
        )
        db_session.add_all([upcoming_match, bygone_match])
        db_session.commit()
        db_session.refresh(upcoming_match)
        db_session.refresh(bygone_match)
        service.create_prediction_for_match(upcoming_match.id, 65.0, 45.0)
        service.create_prediction_for_match(bygone_match.id, 60.0, 50.0)
        resp = client.get("/api/v1/predictions?date_min=2025-01-01T00:00:00Z")
        assert resp.status_code == 200
        body = resp.json()
        # The future match must survive the filter.
        assert len(body["data"]) >= 1
        assert any(entry["match"]["date"] >= "2025-01-01" for entry in body["data"])

    def test_get_predictions_invalid_limit(self, client: TestClient):
        """A limit above the allowed maximum is rejected by validation (422)."""
        resp = client.get("/api/v1/predictions?limit=150")
        assert resp.status_code == 422

    def test_get_predictions_empty(self, client: TestClient):
        """With nothing stored, the listing is empty with total 0."""
        resp = client.get("/api/v1/predictions")
        assert resp.status_code == 200
        body = resp.json()
        assert len(body["data"]) == 0
        assert body["meta"]["total"] == 0
class TestGetPredictionByMatchIdEndpoint:
    """Test GET /api/v1/predictions/match/{match_id} endpoint."""

    def test_get_prediction_by_match_id_success(self, client: TestClient, db_session: Session, sample_match: Match):
        """The latest prediction plus match details and history is returned."""
        from app.services.prediction_service import PredictionService
        service = PredictionService(db_session)
        created = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=65.0,
            away_energy=45.0,
        )
        resp = client.get(f"/api/v1/predictions/match/{sample_match.id}")
        assert resp.status_code == 200
        body = resp.json()
        assert body["id"] == created.id
        assert body["match_id"] == sample_match.id
        assert "match" in body
        assert body["match"]["home_team"] == "PSG"
        assert body["match"]["away_team"] == "Olympique de Marseille"
        for field in ("energy_score", "confidence", "predicted_winner", "history"):
            assert field in body
        assert len(body["history"]) >= 1

    def test_get_prediction_by_match_id_with_history(self, client: TestClient, db_session: Session, sample_match: Match):
        """History lists every prediction for the match, newest first."""
        import time
        from app.services.prediction_service import PredictionService
        service = PredictionService(db_session)
        oldest = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=50.0,
            away_energy=40.0,
        )
        time.sleep(0.01)  # force distinct created_at timestamps
        middle = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=60.0,
            away_energy=35.0,
        )
        time.sleep(0.01)
        newest = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=70.0,
            away_energy=30.0,
        )
        resp = client.get(f"/api/v1/predictions/match/{sample_match.id}")
        assert resp.status_code == 200
        body = resp.json()
        # The payload itself is the latest prediction.
        assert body["id"] == newest.id
        # History is ordered newest-first and complete.
        assert len(body["history"]) == 3
        assert body["history"][0]["id"] == newest.id
        assert body["history"][1]["id"] == middle.id
        assert body["history"][2]["id"] == oldest.id

    def test_get_prediction_by_match_id_not_found(self, client: TestClient):
        """An unknown match id yields a 404."""
        resp = client.get("/api/v1/predictions/match/999")
        assert resp.status_code == 404
        assert "No predictions found" in resp.json()["detail"]

    def test_get_prediction_by_match_id_no_predictions(self, client: TestClient, sample_match: Match):
        """A known match with no predictions also yields a 404."""
        resp = client.get(f"/api/v1/predictions/match/{sample_match.id}")
        assert resp.status_code == 404

View File

@@ -0,0 +1,212 @@
"""
Unit tests for Prediction Calculator module.
"""
import pytest
from app.ml.prediction_calculator import (
calculate_prediction,
calculate_confidence_meter,
determine_winner,
validate_prediction_result
)
class TestConfidenceMeterCalculation:
    """Tests for the Confidence Meter calculation logic."""

    def test_confidence_meter_equal_energy(self):
        """Equal energies produce zero confidence."""
        assert calculate_confidence_meter(home_energy=50.0, away_energy=50.0) == 0.0

    def test_confidence_meter_small_difference(self):
        """A 2-point gap maps to 4% (2x the difference)."""
        meter = calculate_confidence_meter(home_energy=52.0, away_energy=50.0)
        assert 0 < meter < 100
        assert meter == 4.0

    def test_confidence_meter_medium_difference(self):
        """A 10-point gap maps to 20%."""
        meter = calculate_confidence_meter(home_energy=60.0, away_energy=50.0)
        assert 0 < meter < 100
        assert meter == 20.0

    def test_confidence_meter_large_difference(self):
        """A 30-point gap maps to 60%."""
        meter = calculate_confidence_meter(home_energy=80.0, away_energy=50.0)
        assert 0 < meter <= 100
        assert meter == 60.0

    def test_confidence_meter_very_large_difference(self):
        """The meter is capped at 100% for a 50-point gap."""
        assert calculate_confidence_meter(home_energy=100.0, away_energy=50.0) == 100.0

    def test_confidence_meter_zero_energy(self):
        """A zero-energy side yields maximum confidence."""
        assert calculate_confidence_meter(home_energy=0.0, away_energy=50.0) == 100.0

    def test_confidence_meter_negative_energy(self):
        """Negative energy still caps the meter at 100%."""
        assert calculate_confidence_meter(home_energy=-10.0, away_energy=50.0) == 100.0
class TestWinnerDetermination:
    """Tests for the winner determination logic."""

    def test_home_team_wins(self):
        """Higher home energy selects 'home'."""
        assert determine_winner(home_energy=60.0, away_energy=40.0) == "home"

    def test_away_team_wins(self):
        """Higher away energy selects 'away'."""
        assert determine_winner(home_energy=40.0, away_energy=60.0) == "away"

    def test_draw_equal_energy(self):
        """Exactly equal energies select 'draw'."""
        assert determine_winner(home_energy=50.0, away_energy=50.0) == "draw"

    def test_draw_almost_equal(self):
        """A marginal home edge still selects 'home', not 'draw'."""
        assert determine_winner(home_energy=50.01, away_energy=50.0) == "home"
class TestPredictionCalculation:
    """Tests for the complete prediction calculation."""

    def test_prediction_home_win(self):
        """A home-favouring pair yields a home pick with all fields set."""
        outcome = calculate_prediction(home_energy=65.0, away_energy=45.0)
        for key in ("confidence", "predicted_winner", "home_energy", "away_energy"):
            assert key in outcome
        assert outcome["confidence"] > 0
        assert outcome["predicted_winner"] == "home"
        assert outcome["home_energy"] == 65.0
        assert outcome["away_energy"] == 45.0

    def test_prediction_away_win(self):
        """An away-favouring pair yields an away pick."""
        outcome = calculate_prediction(home_energy=35.0, away_energy=70.0)
        assert outcome["confidence"] > 0
        assert outcome["predicted_winner"] == "away"
        assert outcome["home_energy"] == 35.0
        assert outcome["away_energy"] == 70.0

    def test_prediction_draw(self):
        """Equal energies yield a draw with zero confidence."""
        outcome = calculate_prediction(home_energy=50.0, away_energy=50.0)
        assert outcome["confidence"] == 0.0
        assert outcome["predicted_winner"] == "draw"

    def test_prediction_high_confidence(self):
        """A 60-point gap yields confidence of at least 80%."""
        outcome = calculate_prediction(home_energy=90.0, away_energy=30.0)
        assert outcome["confidence"] >= 80.0
        assert outcome["predicted_winner"] == "home"
class TestPredictionValidation:
    """Tests for prediction result validation."""

    @staticmethod
    def _result(**overrides):
        """Build a known-valid result dict, then apply *overrides*."""
        base = {
            'confidence': 75.0,
            'predicted_winner': 'home',
            'home_energy': 65.0,
            'away_energy': 45.0
        }
        base.update(overrides)
        return base

    def test_valid_prediction_result(self):
        """A well-formed result passes validation."""
        assert validate_prediction_result(self._result()) is True

    def test_invalid_confidence_negative(self):
        """Negative confidence fails validation."""
        assert validate_prediction_result(self._result(confidence=-10.0)) is False

    def test_invalid_confidence_over_100(self):
        """Confidence above 100% fails validation."""
        assert validate_prediction_result(self._result(confidence=150.0)) is False

    def test_invalid_winner_value(self):
        """An unknown winner label fails validation."""
        bad = self._result(predicted_winner='invalid')
        assert validate_prediction_result(bad) is False

    def test_invalid_missing_fields(self):
        """A result missing the energy fields fails validation."""
        partial = {
            'confidence': 75.0,
            'predicted_winner': 'home'
            # Missing home_energy and away_energy
        }
        assert validate_prediction_result(partial) is False

    def test_invalid_energy_negative(self):
        """Negative energy values fail validation."""
        assert validate_prediction_result(self._result(home_energy=-10.0)) is False
class TestPredictionEdgeCases:
    """Edge-case tests for the prediction calculation."""

    def test_both_teams_zero_energy(self):
        """Two zero-energy teams yield a zero-confidence draw."""
        outcome = calculate_prediction(home_energy=0.0, away_energy=0.0)
        assert outcome['confidence'] == 0.0
        assert outcome['predicted_winner'] == 'draw'

    def test_very_high_energy_values(self):
        """Huge energy gaps are still capped at 100% confidence."""
        outcome = calculate_prediction(home_energy=1000.0, away_energy=500.0)
        assert outcome['confidence'] == 100.0
        assert outcome['predicted_winner'] == 'home'

    def test_decimal_energy_values(self):
        """Fractional energies produce a bounded, non-zero confidence."""
        outcome = calculate_prediction(home_energy=55.5, away_energy=54.5)
        assert 0 < outcome['confidence'] < 100
        assert outcome['predicted_winner'] == 'home'

View File

@@ -0,0 +1,255 @@
"""
Unit tests for Prediction model.
"""
import pytest
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from app.models.match import Match
from app.models.prediction import Prediction
from app.database import Base, engine, SessionLocal
@pytest.fixture(scope="function")
def db_session():
    """Create a fresh database session for each test.

    Creates the schema, yields an open session, rolls back uncommitted
    work, then closes the session and drops every table so each test
    starts from an empty database.

    Yields:
        Session: an open SQLAlchemy session bound to the app engine.
    """
    # Create tables
    Base.metadata.create_all(bind=engine)
    # Create session
    session = SessionLocal()
    try:
        yield session
        # Discard uncommitted work before teardown.
        session.rollback()
    finally:
        session.close()
        # NOTE(review): shares the application's engine — confirm the
        # configured DSN is a test database before trusting drop_all.
        Base.metadata.drop_all(bind=engine)
class TestPredictionModel:
"""Test Prediction SQLAlchemy model."""
def test_prediction_creation(self, db_session: Session):
"""Test creating a prediction in database."""
# First create a match
match = Match(
home_team="PSG",
away_team="Olympique de Marseille",
date=datetime.now(timezone.utc),
league="Ligue 1",
status="scheduled"
)
db_session.add(match)
db_session.commit()
db_session.refresh(match)
# Create prediction
prediction = Prediction(
match_id=match.id,
energy_score="high",
confidence="85%",
predicted_winner="PSG",
created_at=datetime.now(timezone.utc)
)
db_session.add(prediction)
db_session.commit()
db_session.refresh(prediction)
assert prediction.id is not None
assert prediction.match_id == match.id
assert prediction.energy_score == "high"
assert prediction.confidence == "85%"
assert prediction.predicted_winner == "PSG"
def test_prediction_required_fields(self, db_session: Session):
"""Test that all required fields must be provided."""
# Create a match first
match = Match(
home_team="Barcelona",
away_team="Real Madrid",
date=datetime.now(timezone.utc),
league="La Liga",
status="scheduled"
)
db_session.add(match)
db_session.commit()
db_session.refresh(match)
# Missing predicted_winner
prediction = Prediction(
match_id=match.id,
energy_score="high",
confidence="90%",
created_at=datetime.now(timezone.utc)
)
db_session.add(prediction)
with pytest.raises(Exception): # IntegrityError expected
db_session.commit()
    def test_prediction_foreign_key_constraint(self, db_session: Session):
        """Test that match_id must reference an existing match.

        NOTE(review): SQLite does not enforce foreign keys unless
        ``PRAGMA foreign_keys=ON`` is issued per connection. If the app
        engine targets SQLite without that pragma, this commit may
        succeed and the test would fail for the wrong reason — confirm
        the engine configuration actually enforces FK constraints.
        """
        prediction = Prediction(
            match_id=999,  # Non-existent match
            energy_score="high",
            confidence="90%",
            predicted_winner="PSG",
            created_at=datetime.now(timezone.utc)
        )
        db_session.add(prediction)
        with pytest.raises(Exception):  # IntegrityError expected
            db_session.commit()
def test_prediction_to_dict(self, db_session: Session):
"""Test converting prediction to dictionary."""
match = Match(
home_team="Manchester City",
away_team="Liverpool",
date=datetime.now(timezone.utc),
league="Premier League",
status="scheduled"
)
db_session.add(match)
db_session.commit()
db_session.refresh(match)
prediction = Prediction(
match_id=match.id,
energy_score="medium",
confidence="70%",
predicted_winner="Liverpool",
created_at=datetime.now(timezone.utc)
)
db_session.add(prediction)
db_session.commit()
db_session.refresh(prediction)
prediction_dict = prediction.to_dict()
assert prediction_dict['match_id'] == match.id
assert prediction_dict['energy_score'] == "medium"
assert prediction_dict['confidence'] == "70%"
assert prediction_dict['predicted_winner'] == "Liverpool"
assert 'id' in prediction_dict
assert 'created_at' in prediction_dict
def test_prediction_repr(self, db_session: Session):
"""Test prediction __repr__ method."""
match = Match(
home_team="Juventus",
away_team="Inter Milan",
date=datetime.now(timezone.utc),
league="Serie A",
status="scheduled"
)
db_session.add(match)
db_session.commit()
db_session.refresh(match)
prediction = Prediction(
match_id=match.id,
energy_score="low",
confidence="50%",
predicted_winner="Juventus",
created_at=datetime.now(timezone.utc)
)
db_session.add(prediction)
db_session.commit()
db_session.refresh(prediction)
repr_str = repr(prediction)
assert "Prediction" in repr_str
assert "id=" in repr_str
assert "match_id=" in repr_str
assert "confidence=50%" in repr_str
def test_prediction_match_relationship(self, db_session: Session):
"""Test prediction relationship with match."""
match = Match(
home_team="Bayern Munich",
away_team="Borussia Dortmund",
date=datetime.now(timezone.utc),
league="Bundesliga",
status="scheduled"
)
db_session.add(match)
db_session.commit()
db_session.refresh(match)
prediction = Prediction(
match_id=match.id,
energy_score="very_high",
confidence="95%",
predicted_winner="Bayern Munich",
created_at=datetime.now(timezone.utc)
)
db_session.add(prediction)
db_session.commit()
db_session.refresh(prediction)
# Access relationship
assert prediction.match.id == match.id
assert prediction.match.home_team == "Bayern Munich"
assert prediction.match.away_team == "Borussia Dortmund"
def test_multiple_predictions_per_match(self, db_session: Session):
"""Test that a match can have multiple predictions."""
match = Match(
home_team="Ajax",
away_team="Feyenoord",
date=datetime.now(timezone.utc),
league="Eredivisie",
status="scheduled"
)
db_session.add(match)
db_session.commit()
db_session.refresh(match)
# Create multiple predictions
prediction1 = Prediction(
match_id=match.id,
energy_score="high",
confidence="80%",
predicted_winner="Ajax",
created_at=datetime.now(timezone.utc)
)
prediction2 = Prediction(
match_id=match.id,
energy_score="medium",
confidence="60%",
predicted_winner="Ajax",
created_at=datetime.now(timezone.utc)
)
prediction3 = Prediction(
match_id=match.id,
energy_score="low",
confidence="40%",
predicted_winner="Feyenoord",
created_at=datetime.now(timezone.utc)
)
db_session.add(prediction1)
db_session.add(prediction2)
db_session.add(prediction3)
db_session.commit()
# Refresh match and check predictions
db_session.refresh(match)
assert len(match.predictions) == 3

View File

@@ -0,0 +1,211 @@
"""
Unit tests for Prediction Pydantic schemas.
"""
import pytest
from datetime import datetime, timezone
from pydantic import ValidationError
from app.schemas.prediction import (
PredictionBase,
PredictionCreate,
PredictionUpdate,
PredictionResponse,
PredictionListResponse,
PredictionStatsResponse
)
class TestPredictionBase:
    """Unit tests for the PredictionBase schema's field validation."""

    def test_prediction_base_valid(self):
        """A fully-populated payload builds a valid PredictionBase."""
        payload = dict(
            match_id=1,
            energy_score="high",
            confidence="85%",
            predicted_winner="PSG",
            created_at=datetime.now(timezone.utc),
        )
        base = PredictionBase(**payload)
        assert base.match_id == 1
        assert base.energy_score == "high"
        assert base.confidence == "85%"
        assert base.predicted_winner == "PSG"

    def test_prediction_base_energy_score_too_long(self):
        """energy_score longer than 50 characters is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            PredictionBase(
                match_id=1,
                energy_score="A" * 51,  # one char past the 50-char limit
                confidence="85%",
                predicted_winner="PSG",
                created_at=datetime.now(timezone.utc),
            )
        assert "at most 50 characters" in str(exc_info.value).lower()

    def test_prediction_base_confidence_too_long(self):
        """confidence longer than 50 characters is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            PredictionBase(
                match_id=1,
                energy_score="high",
                confidence="A" * 51,  # one char past the 50-char limit
                predicted_winner="PSG",
                created_at=datetime.now(timezone.utc),
            )
        assert "at most 50 characters" in str(exc_info.value).lower()

    def test_prediction_base_predicted_winner_too_long(self):
        """predicted_winner longer than 255 characters is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            PredictionBase(
                match_id=1,
                energy_score="high",
                confidence="85%",
                predicted_winner="A" * 256,  # one char past the 255-char limit
                created_at=datetime.now(timezone.utc),
            )
        assert "at most 255 characters" in str(exc_info.value).lower()
class TestPredictionCreate:
    """Unit tests for the PredictionCreate schema."""

    def test_prediction_create_valid(self):
        """A complete payload produces a valid PredictionCreate."""
        payload = dict(
            match_id=1,
            energy_score="medium",
            confidence="70%",
            predicted_winner="Barcelona",
            created_at=datetime.now(timezone.utc),
        )
        created = PredictionCreate(**payload)
        assert created.match_id == 1
        assert created.energy_score == "medium"
        assert created.confidence == "70%"
        assert created.predicted_winner == "Barcelona"
class TestPredictionUpdate:
    """Unit tests for the PredictionUpdate schema (all fields optional)."""

    def test_prediction_update_partial(self):
        """Only the supplied field is set; the rest default to None."""
        patch_obj = PredictionUpdate(confidence="90%")
        assert patch_obj.confidence == "90%"
        assert patch_obj.energy_score is None
        assert patch_obj.predicted_winner is None

    def test_prediction_update_all_fields(self):
        """Every updatable field can be supplied at once."""
        patch_obj = PredictionUpdate(
            energy_score="very_high",
            confidence="95%",
            predicted_winner="Real Madrid",
        )
        assert patch_obj.energy_score == "very_high"
        assert patch_obj.confidence == "95%"
        assert patch_obj.predicted_winner == "Real Madrid"

    def test_prediction_update_empty(self):
        """An empty update is valid and leaves every field as None."""
        patch_obj = PredictionUpdate()
        assert patch_obj.energy_score is None
        assert patch_obj.confidence is None
        assert patch_obj.predicted_winner is None
class TestPredictionResponse:
    """Unit tests for the PredictionResponse schema."""

    def test_prediction_response_from_dict(self):
        """A plain dict with all fields hydrates a PredictionResponse."""
        raw = {
            "id": 1,
            "match_id": 1,
            "energy_score": "high",
            "confidence": "85%",
            "predicted_winner": "Manchester City",
            "created_at": datetime.now(timezone.utc),
        }
        resp = PredictionResponse(**raw)
        assert resp.id == 1
        assert resp.match_id == 1
        assert resp.energy_score == "high"
        assert resp.confidence == "85%"
        assert resp.predicted_winner == "Manchester City"
class TestPredictionListResponse:
    """Unit tests for the PredictionListResponse envelope schema."""

    def test_prediction_list_response(self):
        """A list payload plus count and meta builds the envelope."""
        now = datetime.now(timezone.utc)
        rows = [
            {
                "id": 1,
                "match_id": 1,
                "energy_score": "high",
                "confidence": "85%",
                "predicted_winner": "PSG",
                "created_at": now,
            },
            {
                "id": 2,
                "match_id": 2,
                "energy_score": "medium",
                "confidence": "70%",
                "predicted_winner": "Barcelona",
                "created_at": now,
            },
        ]
        envelope = PredictionListResponse(data=rows, count=2, meta={"page": 1})
        assert envelope.count == 2
        assert len(envelope.data) == 2
        assert envelope.meta["page"] == 1
class TestPredictionStatsResponse:
    """Unit tests for the PredictionStatsResponse schema."""

    def test_prediction_stats_response(self):
        """Aggregate statistics round-trip through the schema."""
        raw_stats = {
            "total_predictions": 20,
            "predictions_by_confidence": {"high": 10, "medium": 5, "low": 5},
            "predictions_by_energy_score": {"high": 8, "medium": 7, "low": 5},
            "avg_confidence": 75.5,
            "unique_matches_predicted": 15,
        }
        stats = PredictionStatsResponse(**raw_stats)
        assert stats.total_predictions == 20
        assert stats.predictions_by_confidence["high"] == 10
        assert stats.predictions_by_energy_score["medium"] == 7
        assert stats.avg_confidence == 75.5
        assert stats.unique_matches_predicted == 15

View File

@@ -0,0 +1,258 @@
"""
Unit tests for Prediction Service.
"""
import pytest
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from app.database import Base, engine, SessionLocal
from app.models.match import Match
from app.models.prediction import Prediction
from app.services.prediction_service import PredictionService
@pytest.fixture(scope="function")
def db_session():
    """Yield a fresh database session wrapped in schema setup/teardown.

    Creates all tables before the test and drops them afterwards so tests
    stay isolated.  NOTE(review): this binds to the application's real
    ``engine``/``SessionLocal`` rather than a dedicated test database —
    confirm that is intended (the project conftest uses in-memory SQLite).
    """
    Base.metadata.create_all(bind=engine)
    session = SessionLocal()
    try:
        yield session
        # Undo any uncommitted work the test left behind.
        session.rollback()
    finally:
        # Close and drop even if the test raised (pytest throws into the
        # generator at the yield, so the finally still runs).
        session.close()
        Base.metadata.drop_all(bind=engine)
@pytest.fixture
def sample_match(db_session: Session) -> Match:
    """Persist and return a scheduled Ligue 1 match for service tests."""
    fixture_match = Match(
        home_team="PSG",
        away_team="Olympique de Marseille",
        date=datetime.now(timezone.utc),
        league="Ligue 1",
        status="scheduled",
    )
    db_session.add(fixture_match)
    db_session.commit()
    # Reload so the DB-generated primary key is available to tests.
    db_session.refresh(fixture_match)
    return fixture_match
class TestPredictionServiceCreate:
    """Tests for PredictionService's creation/validation behaviour."""

    def test_create_prediction_home_win(self, db_session: Session, sample_match: Match):
        """Test creating a prediction where home team wins."""
        service = PredictionService(db_session)
        # Higher home energy -> home side predicted as winner.
        prediction = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=65.0,
            away_energy=45.0,
            energy_score_label="high"
        )
        assert prediction.id is not None
        assert prediction.match_id == sample_match.id
        assert prediction.predicted_winner == "PSG"
        assert prediction.energy_score == "high"
        assert "%" in prediction.confidence  # confidence is formatted as a percent string

    def test_create_prediction_away_win(self, db_session: Session, sample_match: Match):
        """Test creating a prediction where away team wins."""
        service = PredictionService(db_session)
        prediction = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=35.0,
            away_energy=70.0
        )
        assert prediction.predicted_winner == "Olympique de Marseille"
        assert prediction.id is not None

    def test_create_prediction_draw(self, db_session: Session, sample_match: Match):
        """Test creating a prediction for draw."""
        service = PredictionService(db_session)
        # Equal energies: expected to yield "Draw" with zero confidence.
        prediction = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=50.0,
            away_energy=50.0
        )
        assert prediction.predicted_winner == "Draw"
        assert prediction.confidence == "0.0%"

    def test_create_prediction_with_default_energy_label(self, db_session: Session, sample_match: Match):
        """Test creating prediction without energy score label."""
        service = PredictionService(db_session)
        # High energy prediction — label derived from the energy values.
        prediction1 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=80.0,
            away_energy=70.0
        )
        assert prediction1.energy_score in ["high", "very_high"]
        # Low energy prediction
        prediction2 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=20.0,
            away_energy=10.0
        )
        assert prediction2.energy_score == "low"

    def test_create_prediction_nonexistent_match(self, db_session: Session):
        """Test creating prediction for non-existent match raises error."""
        service = PredictionService(db_session)
        with pytest.raises(ValueError, match="Match with id 999 not found"):
            service.create_prediction_for_match(
                match_id=999,
                home_energy=65.0,
                away_energy=45.0
            )

    def test_create_prediction_negative_energy(self, db_session: Session, sample_match: Match):
        """Test creating prediction with negative energy raises error."""
        service = PredictionService(db_session)
        with pytest.raises(ValueError, match="cannot be negative"):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=-10.0,
                away_energy=45.0
            )

    def test_create_prediction_invalid_energy_type(self, db_session: Session, sample_match: Match):
        """Test creating prediction with invalid energy type raises error."""
        service = PredictionService(db_session)
        with pytest.raises(ValueError, match="must be numeric"):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy="invalid",  # intentionally wrong type
                away_energy=45.0
            )
class TestPredictionServiceRetrieve:
    """Tests for PredictionService's lookup/retrieval methods."""

    def test_get_prediction_by_id(self, db_session: Session, sample_match: Match):
        """Test retrieving a prediction by ID."""
        service = PredictionService(db_session)
        created = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=65.0,
            away_energy=45.0
        )
        retrieved = service.get_prediction_by_id(created.id)
        assert retrieved is not None
        assert retrieved.id == created.id
        assert retrieved.match_id == created.match_id

    def test_get_prediction_by_id_not_found(self, db_session: Session):
        """Test retrieving non-existent prediction returns None."""
        service = PredictionService(db_session)
        retrieved = service.get_prediction_by_id(999)
        assert retrieved is None

    def test_get_predictions_for_match(self, db_session: Session, sample_match: Match):
        """Test retrieving all predictions for a match."""
        service = PredictionService(db_session)
        # Create multiple predictions with varying energies.
        for i in range(3):
            service.create_prediction_for_match(
                match_id=sample_match.id,
                home_energy=50.0 + i * 10,
                away_energy=40.0 + i * 5
            )
        predictions = service.get_predictions_for_match(sample_match.id)
        assert len(predictions) == 3
        assert all(p.match_id == sample_match.id for p in predictions)

    def test_get_predictions_for_empty_match(self, db_session: Session, sample_match: Match):
        """Test retrieving predictions for match with no predictions."""
        service = PredictionService(db_session)
        predictions = service.get_predictions_for_match(sample_match.id)
        assert len(predictions) == 0

    def test_get_latest_prediction_for_match(self, db_session: Session, sample_match: Match):
        """Test retrieving latest prediction for a match."""
        service = PredictionService(db_session)
        # Create predictions at different times.
        prediction1 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=50.0,
            away_energy=40.0
        )
        # Small delay so created_at timestamps differ; "latest" is assumed
        # to be ordered by that timestamp.
        import time
        time.sleep(0.01)
        prediction2 = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=60.0,
            away_energy=35.0
        )
        latest = service.get_latest_prediction_for_match(sample_match.id)
        assert latest.id == prediction2.id
        assert latest.id != prediction1.id

    def test_get_latest_prediction_for_empty_match(self, db_session: Session, sample_match: Match):
        """Test retrieving latest prediction for match with no predictions."""
        service = PredictionService(db_session)
        latest = service.get_latest_prediction_for_match(sample_match.id)
        assert latest is None
class TestPredictionServiceDelete:
    """Deletion behaviour of PredictionService."""

    def test_delete_prediction(self, db_session: Session, sample_match: Match):
        """Deleting an existing prediction returns True and removes the row."""
        service = PredictionService(db_session)
        created = service.create_prediction_for_match(
            match_id=sample_match.id,
            home_energy=65.0,
            away_energy=45.0,
        )
        assert service.delete_prediction(created.id) is True
        # The row must no longer be retrievable afterwards.
        assert service.get_prediction_by_id(created.id) is None

    def test_delete_prediction_not_found(self, db_session: Session):
        """Deleting an unknown id reports False rather than raising."""
        service = PredictionService(db_session)
        assert service.delete_prediction(999) is False

View File

@@ -0,0 +1,178 @@
"""
Tests for public API endpoints.
This module tests the public API endpoints for predictions and matches.
"""
from datetime import datetime, timezone

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from app.database import get_db
from app.main import app
from app.models.match import Match
from app.models.prediction import Prediction
# Module-level client with no dependency overrides: requests go through the
# app's real ``get_db`` dependency.  NOTE(review): the fixtures below write
# via a separate session — confirm both share one database, otherwise the
# endpoint tests may not see the fixture rows.
client = TestClient(app)
@pytest.fixture
def sample_match(db_session: Session):
    """Create and persist a sample match for the public-API tests.

    Fixed: the parameter was previously named ``db``, but no ``db`` fixture
    exists (conftest provides ``db_session``), so pytest could not resolve
    the request.  Also replaced the deprecated, naive ``datetime.utcnow()``
    with a timezone-aware UTC timestamp, matching the other test modules.
    """
    match = Match(
        home_team="PSG",
        away_team="Olympique de Marseille",
        date=datetime.now(timezone.utc),
        league="Ligue 1",
        status="scheduled",
    )
    db_session.add(match)
    db_session.commit()
    db_session.refresh(match)  # populate the generated primary key
    return match
@pytest.fixture
def sample_prediction(db_session: Session, sample_match: Match):
    """Create and persist a prediction attached to ``sample_match``.

    Fixed: renamed the unresolvable ``db`` fixture parameter to
    ``db_session`` (the fixture conftest actually defines) and replaced
    deprecated naive ``datetime.utcnow()`` with an aware UTC timestamp.
    """
    prediction = Prediction(
        match_id=sample_match.id,
        energy_score="high",
        confidence="70.5%",
        predicted_winner="PSG",
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(prediction)
    db_session.commit()
    db_session.refresh(prediction)  # populate the generated primary key
    return prediction
class TestPublicPredictionsEndpoint:
    """Tests for GET /api/public/v1/predictions endpoint.

    NOTE(review): these tests hit the module-level ``client`` (no
    dependency override) while the fixtures write through a separate
    session — confirm both reach the same database or the non-empty
    assertions may be flaky.
    """

    def test_get_public_predictions_success(self, sample_prediction: Prediction):
        """Test successful retrieval of public predictions."""
        response = client.get("/api/public/v1/predictions")
        assert response.status_code == 200
        data = response.json()
        # Responses follow the standardized {data, meta} envelope.
        assert "data" in data
        assert "meta" in data
        assert isinstance(data["data"], list)
        assert len(data["data"]) > 0
        assert data["meta"]["version"] == "v1"

    def test_get_public_predictions_with_limit(self, sample_prediction: Prediction):
        """Test retrieval of public predictions with limit parameter."""
        response = client.get("/api/public/v1/predictions?limit=5")
        assert response.status_code == 200
        data = response.json()
        assert "data" in data
        assert "meta" in data
        # The effective limit is echoed back in the meta block.
        assert data["meta"]["limit"] == 5

    def test_get_public_predictions_data_structure(self, sample_prediction: Prediction):
        """Test that public predictions have correct structure without sensitive data."""
        response = client.get("/api/public/v1/predictions")
        assert response.status_code == 200
        data = response.json()
        prediction = data["data"][0]
        # Required fields on each prediction entry.
        assert "id" in prediction
        assert "match" in prediction
        assert "energy_score" in prediction
        assert "confidence" in prediction
        assert "predicted_winner" in prediction
        # Nested match details exposed alongside the prediction.
        match = prediction["match"]
        assert "id" in match
        assert "home_team" in match
        assert "away_team" in match
        assert "date" in match
        assert "league" in match
        assert "status" in match

    def test_get_public_predictions_empty_database(self):
        """Test retrieval when no predictions exist."""
        # No fixture requested, so the database should hold no predictions.
        response = client.get("/api/public/v1/predictions")
        assert response.status_code == 200
        data = response.json()
        assert "data" in data
        assert isinstance(data["data"], list)
class TestPublicMatchesEndpoint:
    """Tests for GET /api/public/v1/matches endpoint."""

    def test_get_public_matches_success(self, sample_match: Match):
        """Test successful retrieval of public matches."""
        response = client.get("/api/public/v1/matches")
        assert response.status_code == 200
        data = response.json()
        # Responses follow the standardized {data, meta} envelope.
        assert "data" in data
        assert "meta" in data
        assert isinstance(data["data"], list)
        assert len(data["data"]) > 0
        assert data["meta"]["version"] == "v1"

    def test_get_public_matches_with_filters(self, sample_match: Match):
        """Test retrieval of public matches with league filter."""
        # Fixed: dropped a stray f-string prefix (the literal has no
        # placeholders).  %20 is the URL-encoded space in "Ligue 1".
        response = client.get("/api/public/v1/matches?league=Ligue%201")
        assert response.status_code == 200
        data = response.json()
        assert "data" in data
        assert "meta" in data

    def test_get_public_matches_data_structure(self, sample_match: Match):
        """Test that public matches expose only non-sensitive fields."""
        response = client.get("/api/public/v1/matches")
        assert response.status_code == 200
        data = response.json()
        match = data["data"][0]
        # Required public-only fields on each match entry.
        for field in ("id", "home_team", "away_team", "date", "league", "status"):
            assert field in match

    def test_get_public_matches_empty_database(self):
        """Test retrieval when no matches exist."""
        response = client.get("/api/public/v1/matches")
        assert response.status_code == 200
        data = response.json()
        assert "data" in data
        assert isinstance(data["data"], list)
class TestPublicApiResponseFormat:
    """Checks the standardized {data, meta} envelope of public responses."""

    def test_response_format_success(self, sample_prediction: Prediction):
        """Successful responses carry data plus meta with timestamp/version."""
        result = client.get("/api/public/v1/predictions")
        assert result.status_code == 200
        body = result.json()
        for key in ("data", "meta"):
            assert key in body
        for meta_key in ("timestamp", "version"):
            assert meta_key in body["meta"]

    def test_response_meta_timestamp_format(self, sample_prediction: Prediction):
        """The meta timestamp looks like ISO 8601 (contains 'T' and 'Z')."""
        result = client.get("/api/public/v1/predictions")
        assert result.status_code == 200
        stamp = result.json()["meta"]["timestamp"]
        # Simplified ISO-8601 shape check, as in the rest of the suite.
        assert "T" in stamp
        assert "Z" in stamp

View File

@@ -0,0 +1,223 @@
"""
Tests for RabbitMQ client.
"""
import pytest
import json
from unittest.mock import Mock, patch, MagicMock
import pika
from app.queues.rabbitmq_client import RabbitMQClient, create_rabbitmq_client
class TestRabbitMQClient:
    """Tests for RabbitMQClient, with pika fully mocked out.

    Every test patches ``pika.BlockingConnection`` so no broker is needed;
    assertions inspect the mock's recorded calls.
    """

    def test_initialization(self):
        """Test RabbitMQ client initialization."""
        client = RabbitMQClient(
            rabbitmq_url="amqp://guest:guest@localhost:5672",
            prefetch_count=1
        )
        assert client.rabbitmq_url == "amqp://guest:guest@localhost:5672"
        assert client.prefetch_count == 1
        # No connection is opened until connect() is called.
        assert client.connection is None
        assert client.channel is None
        # The client pre-declares a fixed queue-name mapping.
        assert client.queues == {
            'scraping_tasks': 'scraping_tasks',
            'sentiment_analysis_tasks': 'sentiment_analysis_tasks',
            'energy_calculation_tasks': 'energy_calculation_tasks',
            'results': 'results'
        }

    @patch('pika.BlockingConnection')
    def test_connect_success(self, mock_connection):
        """Test successful connection to RabbitMQ."""
        # Setup mocks
        mock_conn_instance = Mock()
        mock_channel_instance = Mock()
        mock_connection.return_value = mock_conn_instance
        mock_conn_instance.channel.return_value = mock_channel_instance
        # Create client and connect
        client = RabbitMQClient()
        client.connect()
        # Verify connection and channel created
        mock_connection.assert_called_once()
        mock_conn_instance.channel.assert_called_once()
        mock_channel_instance.basic_qos.assert_called_once_with(prefetch_count=1)
        # Verify queues declared — one declaration per entry in client.queues.
        assert mock_channel_instance.queue_declare.call_count == 4
        # Verify client state now holds the live connection and channel.
        assert client.connection == mock_conn_instance
        assert client.channel == mock_channel_instance

    @patch('pika.BlockingConnection')
    def test_publish_message(self, mock_connection):
        """Test publishing a message to a queue."""
        # Setup mocks
        mock_conn_instance = Mock()
        mock_channel_instance = Mock()
        mock_connection.return_value = mock_conn_instance
        mock_conn_instance.channel.return_value = mock_channel_instance
        # Create client and connect
        client = RabbitMQClient()
        client.connect()
        # Publish message
        client.publish_message(
            queue_name='scraping_tasks',
            data={'match_id': 123, 'source': 'twitter'},
            event_type='scraping.task.created'
        )
        # Verify publish called
        mock_channel_instance.basic_publish.assert_called_once()
        call_args = mock_channel_instance.basic_publish.call_args
        # Check routing key — messages are routed by queue name.
        assert call_args[1]['routing_key'] == 'scraping_tasks'
        # Check message body is valid JSON with the envelope fields.
        message_body = call_args[1]['body']
        message = json.loads(message_body)
        assert message['event'] == 'scraping.task.created'
        assert message['version'] == '1.0'
        assert message['data']['match_id'] == 123
        assert message['data']['source'] == 'twitter'
        assert 'timestamp' in message
        # Check message properties
        properties = call_args[1]['properties']
        assert properties.delivery_mode == 2  # Persistent

    @patch('pika.BlockingConnection')
    def test_consume_messages(self, mock_connection):
        """Test consuming messages from a queue."""
        # Setup mocks
        mock_conn_instance = Mock()
        mock_channel_instance = Mock()
        mock_connection.return_value = mock_conn_instance
        mock_conn_instance.channel.return_value = mock_channel_instance
        # Create client and connect
        client = RabbitMQClient()
        client.connect()
        # Mock start_consuming to avoid blocking the test forever.
        mock_channel_instance.start_consuming = Mock(side_effect=KeyboardInterrupt())
        # Define callback
        callback = Mock()
        # Consume messages; the injected KeyboardInterrupt ends the loop.
        try:
            client.consume_messages(queue_name='scraping_tasks', callback=callback)
        except KeyboardInterrupt:
            pass
        # Verify consume called with manual acknowledgements (auto_ack=False).
        mock_channel_instance.basic_consume.assert_called_once()
        call_args = mock_channel_instance.basic_consume.call_args
        assert call_args[1]['queue'] == 'scraping_tasks'
        assert call_args[1]['auto_ack'] is False
        assert call_args[1]['on_message_callback'] == callback
        # Verify start_consuming called
        mock_channel_instance.start_consuming.assert_called_once()

    @patch('pika.BlockingConnection')
    def test_ack_message(self, mock_connection):
        """Test acknowledging a message."""
        # Setup mocks
        mock_conn_instance = Mock()
        mock_channel_instance = Mock()
        mock_connection.return_value = mock_conn_instance
        mock_conn_instance.channel.return_value = mock_channel_instance
        # Create client and connect
        client = RabbitMQClient()
        client.connect()
        # Acknowledge message
        client.ack_message(delivery_tag=123)
        # Verify ack forwarded to the channel with the same delivery tag.
        mock_channel_instance.basic_ack.assert_called_once_with(delivery_tag=123)

    @patch('pika.BlockingConnection')
    def test_reject_message(self, mock_connection):
        """Test rejecting a message."""
        # Setup mocks
        mock_conn_instance = Mock()
        mock_channel_instance = Mock()
        mock_connection.return_value = mock_conn_instance
        mock_conn_instance.channel.return_value = mock_channel_instance
        # Create client and connect
        client = RabbitMQClient()
        client.connect()
        # Reject message without requeue — it should not be redelivered.
        client.reject_message(delivery_tag=123, requeue=False)
        # Verify reject called
        mock_channel_instance.basic_reject.assert_called_once_with(
            delivery_tag=123,
            requeue=False
        )

    @patch('pika.BlockingConnection')
    def test_close_connection(self, mock_connection):
        """Test closing connection to RabbitMQ."""
        # Setup mocks
        mock_conn_instance = Mock()
        mock_channel_instance = Mock()
        mock_connection.return_value = mock_conn_instance
        mock_conn_instance.channel.return_value = mock_channel_instance
        # Create client and connect
        client = RabbitMQClient()
        client.connect()
        # Close connection
        client.close()
        # Verify both the channel and the connection are closed.
        mock_channel_instance.close.assert_called_once()
        mock_conn_instance.close.assert_called_once()
class TestCreateRabbitMQClient:
    """Tests for the create_rabbitmq_client factory function."""

    # Default broker URL the factory falls back to.
    DEFAULT_URL = "amqp://guest:guest@localhost:5672"

    def test_create_with_defaults(self):
        """Factory defaults: local guest URL and a prefetch of 1."""
        instance = create_rabbitmq_client()
        assert isinstance(instance, RabbitMQClient)
        assert instance.rabbitmq_url == self.DEFAULT_URL
        assert instance.prefetch_count == 1

    def test_create_with_custom_url(self):
        """A custom URL is honoured while prefetch stays at its default."""
        instance = create_rabbitmq_client(
            rabbitmq_url="amqp://user:pass@remote:5672"
        )
        assert instance.rabbitmq_url == "amqp://user:pass@remote:5672"
        assert instance.prefetch_count == 1

    def test_create_with_custom_prefetch(self):
        """A custom prefetch count is honoured while URL stays default."""
        instance = create_rabbitmq_client(prefetch_count=5)
        assert instance.rabbitmq_url == self.DEFAULT_URL
        assert instance.prefetch_count == 5

View File

@@ -0,0 +1,146 @@
"""
Tests for RabbitMQ message consumers.
"""
import pytest
import json
from unittest.mock import Mock, patch, MagicMock
from sqlalchemy.orm import Session
from app.queues.consumers import (
consume_scraping_tasks,
consume_sentiment_analysis_tasks,
consume_energy_calculation_tasks,
consume_results
)
class TestConsumeScrapingTasks:
    """Tests for consume_scraping_tasks function."""

    @patch('app.queues.consumers.client')
    @patch('app.queues.consumers.consume_messages')
    def test_consume_scraping_tasks_success(self, mock_consume, mock_client):
        """Test successful scraping task consumption."""
        # Setup mocks
        mock_db_session = Mock(spec=Session)
        mock_db_factory = Mock(return_value=mock_db_session)
        mock_callback = Mock(return_value={'collected_count': 50})
        # Call consume
        consume_scraping_tasks(
            client=mock_client,
            callback=mock_callback,
            db_session_factory=mock_db_factory
        )
        # The consumer must subscribe to the 'scraping_tasks' queue with a
        # non-None wrapper callback.
        mock_consume.assert_called_once()
        call_args = mock_consume.call_args
        assert call_args[0][0] == 'scraping_tasks'
        assert call_args[1]['callback'] is not None

    @patch('app.queues.consumers.client')
    @patch('app.queues.consumers.consume_messages')
    def test_consume_scraping_tasks_error_handling(self, mock_consume, mock_client):
        """Test error handling in scraping task consumption.

        Fixed: removed an unused inner ``capture_callback`` helper that was
        defined but never invoked.  NOTE(review): this test still only
        verifies that subscription happens even when the callback is set up
        to raise; no message is ever driven through the callback, so the
        error path itself is not exercised — TODO strengthen.
        """
        mock_db_session = Mock(spec=Session)
        mock_db_factory = Mock(return_value=mock_db_session)
        mock_callback = Mock(side_effect=Exception("Processing error"))
        # Call consume with the failing callback.
        consume_scraping_tasks(
            client=mock_client,
            callback=mock_callback,
            db_session_factory=mock_db_factory
        )
        # Subscription still happens despite the callback being faulty.
        mock_consume.assert_called_once()
class TestConsumeSentimentAnalysisTasks:
    """Tests for consume_sentiment_analysis_tasks function."""

    @patch('app.queues.consumers.client')
    @patch('app.queues.consumers.consume_messages')
    def test_consume_sentiment_analysis_tasks_success(self, mock_consume, mock_client):
        """The consumer subscribes to the sentiment-analysis queue."""
        fake_session = Mock(spec=Session)
        session_factory = Mock(return_value=fake_session)
        handler = Mock(return_value={'analyzed_count': 100})
        consume_sentiment_analysis_tasks(
            client=mock_client,
            callback=handler,
            db_session_factory=session_factory
        )
        # Exactly one subscription, on the expected queue, with a wrapper.
        mock_consume.assert_called_once()
        positional, keywords = mock_consume.call_args
        assert positional[0] == 'sentiment_analysis_tasks'
        assert keywords['callback'] is not None
class TestConsumeEnergyCalculationTasks:
    """Tests for consume_energy_calculation_tasks function."""

    @patch('app.queues.consumers.client')
    @patch('app.queues.consumers.consume_messages')
    def test_consume_energy_calculation_tasks_success(self, mock_consume, mock_client):
        """The consumer subscribes to the energy-calculation queue."""
        fake_session = Mock(spec=Session)
        session_factory = Mock(return_value=fake_session)
        handler = Mock(return_value={
            'energy_score': 75.5,
            'confidence': 0.82
        })
        consume_energy_calculation_tasks(
            client=mock_client,
            callback=handler,
            db_session_factory=session_factory
        )
        # Exactly one subscription, on the expected queue, with a wrapper.
        mock_consume.assert_called_once()
        positional, keywords = mock_consume.call_args
        assert positional[0] == 'energy_calculation_tasks'
        assert keywords['callback'] is not None
class TestConsumeResults:
    """Tests for consume_results function."""

    @patch('app.queues.consumers.client')
    @patch('app.queues.consumers.consume_messages')
    def test_consume_results_success(self, mock_consume, mock_client):
        """The consumer subscribes to the results queue."""
        handler = Mock()
        consume_results(
            client=mock_client,
            callback=handler
        )
        # Exactly one subscription, on the expected queue, with a wrapper.
        mock_consume.assert_called_once()
        positional, keywords = mock_consume.call_args
        assert positional[0] == 'results'
        assert keywords['callback'] is not None

View File

@@ -0,0 +1,266 @@
"""
Tests for RabbitMQ message producers.
"""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime
from app.queues.producers import (
publish_scraping_task,
publish_sentiment_analysis_task,
publish_energy_calculation_task,
publish_result,
publish_scraping_result,
publish_sentiment_analysis_result,
publish_energy_calculation_result
)
class TestPublishScrapingTask:
    """Tests for publish_scraping_task, with the broker client mocked.

    ``app.queues.producers.datetime`` is patched so the ``created_at``
    field is deterministic.
    """

    @patch('app.queues.producers.datetime')
    def test_publish_scraping_task_twitter(self, mock_datetime):
        """Test publishing a Twitter scraping task."""
        # Mock datetime so created_at is deterministic.
        mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"
        # Create mock client
        mock_client = Mock()
        mock_client.publish_message = Mock()
        # Publish task
        publish_scraping_task(
            client=mock_client,
            match_id=123,
            source='twitter',
            keywords=['#MatchName', 'team1 vs team2'],
            priority='normal'
        )
        # Verify publish_message called exactly once.
        mock_client.publish_message.assert_called_once()
        call_args = mock_client.publish_message.call_args
        # Check queue routing and event type of the published message.
        assert call_args[1]['queue_name'] == 'scraping_tasks'
        assert call_args[1]['event_type'] == 'scraping.task.created'
        # Check that the task payload carries every expected field.
        task_data = call_args[1]['data']
        assert task_data['task_type'] == 'scraping'
        assert task_data['match_id'] == 123
        assert task_data['source'] == 'twitter'
        assert task_data['keywords'] == ['#MatchName', 'team1 vs team2']
        assert task_data['priority'] == 'normal'
        assert 'created_at' in task_data

    @patch('app.queues.producers.datetime')
    def test_publish_scraping_task_vip(self, mock_datetime):
        """Test publishing a VIP scraping task."""
        mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"
        mock_client = Mock()
        mock_client.publish_message = Mock()
        publish_scraping_task(
            client=mock_client,
            match_id=456,
            source='reddit',
            keywords=['Ligue1'],
            priority='vip'
        )
        # The 'vip' priority and 'reddit' source must pass through intact.
        call_args = mock_client.publish_message.call_args
        task_data = call_args[1]['data']
        assert task_data['priority'] == 'vip'
        assert task_data['source'] == 'reddit'
class TestPublishSentimentAnalysisTask:
    """Tests for publish_sentiment_analysis_task function."""

    @patch('app.queues.producers.datetime')
    def test_publish_sentiment_analysis_task_twitter(self, mock_datetime):
        """A Twitter sentiment task is routed with its ids and an empty texts list."""
        mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"

        broker = Mock()
        broker.publish_message = Mock()

        publish_sentiment_analysis_task(
            client=broker,
            match_id=123,
            source='twitter',
            entity_ids=['tweet1', 'tweet2', 'tweet3']
        )

        kwargs = broker.publish_message.call_args.kwargs
        assert kwargs['queue_name'] == 'sentiment_analysis_tasks'
        assert kwargs['event_type'] == 'sentiment_analysis.task.created'

        payload = kwargs['data']
        assert payload['task_type'] == 'sentiment_analysis'
        assert payload['match_id'] == 123
        assert payload['source'] == 'twitter'
        assert payload['entity_ids'] == ['tweet1', 'tweet2', 'tweet3']
        # When texts are omitted the producer defaults to an empty list.
        assert payload['texts'] == []
        assert 'created_at' in payload

    @patch('app.queues.producers.datetime')
    def test_publish_sentiment_analysis_task_with_texts(self, mock_datetime):
        """Explicit texts are forwarded verbatim in the payload."""
        mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"

        broker = Mock()
        broker.publish_message = Mock()

        sample_texts = ["Great match!", "Amazing goal", "What a game"]
        publish_sentiment_analysis_task(
            client=broker,
            match_id=123,
            source='reddit',
            entity_ids=['post1', 'post2', 'post3'],
            texts=sample_texts
        )

        payload = broker.publish_message.call_args.kwargs['data']
        assert payload['texts'] == sample_texts
        assert payload['source'] == 'reddit'
class TestPublishEnergyCalculationTask:
    """Tests for publish_energy_calculation_task function."""

    @patch('app.queues.producers.datetime')
    def test_publish_energy_calculation_task(self, mock_datetime):
        """An energy task carries its sentiments and defaults the optional lists."""
        mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"

        broker = Mock()
        broker.publish_message = Mock()

        sentiments = [
            {'compound': 0.5, 'sentiment': 'positive'},
            {'compound': -0.3, 'sentiment': 'negative'}
        ]
        publish_energy_calculation_task(
            client=broker,
            match_id=123,
            team_id=456,
            twitter_sentiments=sentiments
        )

        kwargs = broker.publish_message.call_args.kwargs
        assert kwargs['queue_name'] == 'energy_calculation_tasks'
        assert kwargs['event_type'] == 'energy_calculation.task.created'

        payload = kwargs['data']
        assert payload['task_type'] == 'energy_calculation'
        assert payload['match_id'] == 123
        assert payload['team_id'] == 456
        assert payload['twitter_sentiments'] == sentiments
        # Unsupplied optional inputs default to empty lists.
        for optional in ('reddit_sentiments', 'rss_sentiments', 'tweets_with_timestamps'):
            assert payload[optional] == []
        assert 'created_at' in payload
class TestPublishResult:
    """Tests for publish_result function."""

    @patch('app.queues.producers.datetime')
    def test_publish_scraping_result_wrapper(self, mock_datetime):
        """A scraping result is wrapped and routed to the results queue."""
        mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"

        broker = Mock()
        broker.publish_message = Mock()

        publish_scraping_result(
            client=broker,
            match_id=123,
            source='twitter',
            collected_count=100,
            metadata={'keywords': ['#MatchName']}
        )

        kwargs = broker.publish_message.call_args.kwargs
        assert kwargs['queue_name'] == 'results'
        assert kwargs['event_type'] == 'result.published'

        # Results are double-wrapped: a result_type envelope around the payload.
        envelope = kwargs['data']
        assert envelope['result_type'] == 'scraping'
        inner = envelope['data']
        assert inner['match_id'] == 123
        assert inner['source'] == 'twitter'
        assert inner['collected_count'] == 100
        assert inner['status'] == 'success'
        assert inner['metadata'] == {'keywords': ['#MatchName']}

    @patch('app.queues.producers.datetime')
    def test_publish_sentiment_analysis_result_wrapper(self, mock_datetime):
        """A sentiment result keeps its aggregate metrics intact."""
        mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"

        broker = Mock()
        broker.publish_message = Mock()

        aggregate_metrics = {
            'total_count': 100,
            'positive_count': 60,
            'negative_count': 20,
            'neutral_count': 20,
            'average_compound': 0.35
        }
        publish_sentiment_analysis_result(
            client=broker,
            match_id=123,
            source='twitter',
            analyzed_count=100,
            metrics=aggregate_metrics
        )

        envelope = broker.publish_message.call_args.kwargs['data']
        assert envelope['result_type'] == 'sentiment'
        inner = envelope['data']
        assert inner['match_id'] == 123
        assert inner['source'] == 'twitter'
        assert inner['analyzed_count'] == 100
        assert inner['metrics'] == aggregate_metrics
        assert inner['status'] == 'success'

    @patch('app.queues.producers.datetime')
    def test_publish_energy_calculation_result_wrapper(self, mock_datetime):
        """An energy result carries score, confidence and the sources used."""
        mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"

        broker = Mock()
        broker.publish_message = Mock()

        publish_energy_calculation_result(
            client=broker,
            match_id=123,
            team_id=456,
            energy_score=78.5,
            confidence=0.82,
            sources_used=['twitter', 'reddit']
        )

        envelope = broker.publish_message.call_args.kwargs['data']
        assert envelope['result_type'] == 'energy'
        inner = envelope['data']
        assert inner['match_id'] == 123
        assert inner['team_id'] == 456
        assert inner['energy_score'] == 78.5
        assert inner['confidence'] == 0.82
        assert inner['sources_used'] == ['twitter', 'reddit']
        assert inner['status'] == 'success'

View File

@@ -0,0 +1,381 @@
"""
Unit tests for Reddit scraper.
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import Mock, patch, MagicMock
import praw.exceptions
from app.scrapers.reddit_scraper import (
RedditScraper,
RedditPostData,
RedditCommentData,
create_reddit_scraper
)
@pytest.fixture
def mock_praw():
    """Yield the patched ``praw.Reddit`` class as seen by the reddit_scraper module."""
    with patch('app.scrapers.reddit_scraper.praw.Reddit') as patched_reddit:
        yield patched_reddit
@pytest.fixture
def test_client_id():
    """Reddit API client ID shared by the scraper fixtures."""
    client_id = "test_client_id"
    return client_id
@pytest.fixture
def test_client_secret():
    """Reddit API client secret shared by the scraper fixtures."""
    client_secret = "test_client_secret"
    return client_secret
@pytest.fixture
def test_scraper(test_client_id, test_client_secret):
    """Build a RedditScraper with praw patched out so no network auth happens."""
    with patch('app.scrapers.reddit_scraper.praw.Reddit'):
        return RedditScraper(
            client_id=test_client_id,
            client_secret=test_client_secret,
            subreddits=["soccer", "football"],
            max_posts_per_subreddit=100,
            max_comments_per_post=50
        )
class TestRedditPostData:
    """Test RedditPostData dataclass."""

    def test_reddit_post_data_creation(self):
        """A freshly built RedditPostData keeps every field and defaults source."""
        stamp = datetime.now(timezone.utc)
        record = RedditPostData(
            post_id="abc123",
            title="Test post title",
            text="Test post content",
            upvotes=100,
            created_at=stamp,
            match_id=1,
            subreddit="soccer"
        )
        expected = {
            "post_id": "abc123",
            "title": "Test post title",
            "text": "Test post content",
            "upvotes": 100,
            "match_id": 1,
            "subreddit": "soccer",
            # 'source' is not passed in: it must default to "reddit".
            "source": "reddit",
        }
        for attr, value in expected.items():
            assert getattr(record, attr) == value
class TestRedditCommentData:
    """Test RedditCommentData dataclass."""

    def test_reddit_comment_data_creation(self):
        """A freshly built RedditCommentData keeps all fields and defaults source."""
        stamp = datetime.now(timezone.utc)
        record = RedditCommentData(
            comment_id="def456",
            post_id="abc123",
            text="Test comment content",
            upvotes=50,
            created_at=stamp
        )
        expected = {
            "comment_id": "def456",
            "post_id": "abc123",
            "text": "Test comment content",
            "upvotes": 50,
            # 'source' is not passed in: it must default to "reddit".
            "source": "reddit",
        }
        for attr, value in expected.items():
            assert getattr(record, attr) == value
class TestRedditScraper:
    """Test RedditScraper class.

    All tests patch ``praw.Reddit`` so no network calls happen.
    NOTE(review): the ``test_scraper`` fixture builds its scraper inside its
    own (already-exited) patch context, while several tests also request the
    ``mock_praw`` fixture and configure ``mock_praw.return_value`` — these
    tests implicitly assume the scraper resolves the patched module attribute
    at call time; confirm against the implementation.
    """

    def test_scraper_initialization(self, test_client_id, test_client_secret):
        """Test scraper initialization: constructor args are stored verbatim."""
        with patch('app.scrapers.reddit_scraper.praw.Reddit'):
            scraper = RedditScraper(
                client_id=test_client_id,
                client_secret=test_client_secret,
                subreddits=["soccer", "football"],
                max_posts_per_subreddit=100,
                max_comments_per_post=50
            )
            assert scraper.client_id == test_client_id
            assert scraper.client_secret == test_client_secret
            assert scraper.subreddits == ["soccer", "football"]
            assert scraper.max_posts_per_subreddit == 100
            assert scraper.max_comments_per_post == 50

    def test_verify_authentication_success(self, mock_praw, caplog):
        """Test successful authentication verification."""
        # reddit.user.me() succeeding means authentication is considered OK.
        mock_reddit = Mock()
        mock_reddit.user.me.return_value = Mock(name="test_user")
        mock_praw.return_value = mock_reddit
        with caplog.at_level("INFO"):
            scraper = RedditScraper(
                client_id="test_id",
                client_secret="test_secret",
                subreddits=["soccer"]
            )
            # Constructor logs a success marker once authenticated.
            assert "✅ Reddit API authenticated successfully" in caplog.text

    def test_verify_authentication_failure(self, mock_praw, caplog):
        """Test failed authentication verification."""
        mock_reddit = Mock()
        # user.me() raising should surface as an authentication failure.
        mock_reddit.user.me.side_effect = Exception("Auth failed")
        mock_praw.return_value = mock_reddit
        with pytest.raises(Exception, match="Reddit API authentication failed"):
            RedditScraper(
                client_id="invalid_id",
                client_secret="invalid_secret",
                subreddits=["soccer"]
            )

    def test_scrape_posts_empty(self, test_scraper, mock_praw, caplog):
        """Test scraping posts with no results: returns [] and logs it."""
        mock_subreddit = Mock()
        mock_subreddit.new.return_value = []
        mock_reddit = Mock()
        mock_reddit.subreddit.return_value = mock_subreddit
        mock_praw.return_value = mock_reddit
        with caplog.at_level("INFO"):
            result = test_scraper.scrape_posts(
                subreddit="soccer",
                match_id=1,
                keywords=["test"]
            )
            assert result == []
            # NOTE(review): the leading space suggests a stripped emoji in the
            # expected log line — verify against the scraper's log message.
            assert " No posts found" in caplog.text

    def test_scrape_posts_success(self, test_scraper, mock_praw):
        """Test successful post scraping: praw submissions map onto RedditPostData."""
        # Mock post — attributes mirror praw's Submission fields.
        mock_post = Mock()
        mock_post.id = "abc123"
        mock_post.title = "Test match discussion"
        mock_post.selftext = "Great match today!"
        mock_post.score = 100
        mock_post.created_utc = 1700000000.0
        mock_subreddit = Mock()
        mock_subreddit.new.return_value = [mock_post]
        mock_reddit = Mock()
        mock_reddit.subreddit.return_value = mock_subreddit
        mock_praw.return_value = mock_reddit
        result = test_scraper.scrape_posts(
            subreddit="soccer",
            match_id=1,
            keywords=["test"]
        )
        assert len(result) == 1
        assert result[0].post_id == "abc123"
        assert result[0].title == "Test match discussion"
        assert result[0].upvotes == 100

    def test_scrape_comments_success(self, test_scraper, mock_praw):
        """Test successful comment scraping via post.comments.list()."""
        # Mock comment — attributes mirror praw's Comment fields.
        mock_comment = Mock()
        mock_comment.id = "def456"
        mock_comment.body = "Great goal!"
        mock_comment.score = 50
        mock_comment.created_utc = 1700000000.0
        mock_post = Mock()
        mock_post.comments.list.return_value = [mock_comment]
        result = test_scraper.scrape_comments(post_id="abc123", post=mock_post)
        assert len(result) == 1
        assert result[0].comment_id == "def456"
        assert result[0].text == "Great goal!"
        assert result[0].upvotes == 50

    def test_save_posts_to_db(self, test_scraper, mock_praw):
        """Test saving posts to database: new post is added and committed."""
        from app.models.reddit_post import RedditPost
        # Mock posts data
        posts_data = [
            RedditPostData(
                post_id="abc123",
                title="Test post",
                text="Test content",
                upvotes=100,
                created_at=datetime.now(timezone.utc),
                match_id=1,
                subreddit="soccer"
            )
        ]
        # Mock database session — first() returning None means "not a duplicate".
        mock_db = Mock()
        mock_db.query.return_value.filter.return_value.first.return_value = None
        # Mock model instance the scraper will construct and persist.
        mock_post_instance = Mock()
        mock_post_instance.id = 1
        with patch('app.scrapers.reddit_scraper.RedditPost', return_value=mock_post_instance):
            test_scraper.save_posts_to_db(posts_data, mock_db)
            mock_db.add.assert_called_once()
            mock_db.commit.assert_called_once()

    def test_save_comments_to_db(self, test_scraper):
        """Test saving comments to database: new comment is added and committed."""
        # Mock comments data
        comments_data = [
            RedditCommentData(
                comment_id="def456",
                post_id="abc123",
                text="Test comment",
                upvotes=50,
                created_at=datetime.now(timezone.utc)
            )
        ]
        # Mock database session — first() returning None means "not a duplicate".
        mock_db = Mock()
        mock_db.query.return_value.filter.return_value.first.return_value = None
        # Mock model instance the scraper will construct and persist.
        mock_comment_instance = Mock()
        mock_comment_instance.id = 1
        with patch('app.scrapers.reddit_scraper.RedditComment', return_value=mock_comment_instance):
            test_scraper.save_comments_to_db(comments_data, mock_db)
            mock_db.add.assert_called_once()
            mock_db.commit.assert_called_once()

    def test_save_posts_to_db_existing(self, test_scraper):
        """Test saving posts when post already exists (dedup path)."""
        posts_data = [
            RedditPostData(
                post_id="abc123",
                title="Test post",
                text="Test content",
                upvotes=100,
                created_at=datetime.now(timezone.utc),
                match_id=1,
                subreddit="soccer"
            )
        ]
        mock_db = Mock()
        # first() returning an object means the post is already stored.
        mock_db.query.return_value.filter.return_value.first.return_value = Mock()
        test_scraper.save_posts_to_db(posts_data, mock_db)
        # The duplicate must not be added; the scraper still issues one
        # commit at the end of the batch.
        mock_db.add.assert_not_called()
        mock_db.commit.assert_called_once()

    def test_scrape_posts_api_error_continues(self, test_scraper, mock_praw, caplog):
        """Test that scraper returns [] and logs on a PRAW API error."""
        mock_subreddit = Mock()
        mock_subreddit.new.side_effect = praw.exceptions.PRAWException("API Error")
        mock_reddit = Mock()
        mock_reddit.subreddit.return_value = mock_subreddit
        mock_praw.return_value = mock_reddit
        with caplog.at_level("ERROR"):
            result = test_scraper.scrape_posts(
                subreddit="soccer",
                match_id=1,
                keywords=["test"]
            )
            assert result == []
            assert "Reddit API error" in caplog.text

    def test_scrape_comments_api_error_continues(self, test_scraper, caplog):
        """Test that scraper returns [] and logs on comment API error."""
        mock_post = Mock()
        mock_post.comments.list.side_effect = praw.exceptions.PRAWException("API Error")
        with caplog.at_level("ERROR"):
            result = test_scraper.scrape_comments(post_id="abc123", post=mock_post)
            assert result == []
            assert "Reddit API error" in caplog.text

    def test_save_posts_db_error_rollback(self, test_scraper, caplog):
        """Test that database errors trigger rollback and re-raise."""
        posts_data = [
            RedditPostData(
                post_id="abc123",
                title="Test post",
                text="Test content",
                upvotes=100,
                created_at=datetime.now(timezone.utc),
                match_id=1,
                subreddit="soccer"
            )
        ]
        mock_db = Mock()
        mock_db.query.return_value.filter.return_value.first.return_value = None
        # Force the commit itself to fail.
        mock_db.commit.side_effect = Exception("Database error")
        with caplog.at_level("ERROR"):
            with pytest.raises(Exception):
                test_scraper.save_posts_to_db(posts_data, mock_db)
            mock_db.rollback.assert_called_once()
            assert "Failed to save Reddit posts" in caplog.text

    def test_scrape_reddit_match_continues_on_subreddit_error(
        self, test_scraper, mock_praw, caplog
    ):
        """Test that scraper continues with other subreddits on error."""
        # First subreddit raises; the second yields no posts.
        mock_subreddit1 = Mock()
        mock_subreddit1.new.side_effect = praw.exceptions.PRAWException("API Error")
        mock_subreddit2 = Mock()
        mock_subreddit2.new.return_value = []
        mock_reddit = Mock()
        mock_reddit.subreddit.side_effect = [mock_subreddit1, mock_subreddit2]
        mock_praw.return_value = mock_reddit
        test_scraper.subreddits = ["soccer", "football"]
        with caplog.at_level("ERROR"):
            result = test_scraper.scrape_reddit_match(match_id=1)
            # Nothing collected, but the scraper must not abort the whole run.
            assert result['posts'] == []
            assert result['comments'] == []
            assert "Continuing with other sources" in caplog.text
class TestCreateRedditScraper:
    """Test create_reddit_scraper factory function."""

    def test_factory_function(self, test_client_id, test_client_secret):
        """The factory delegates to RedditScraper and returns its instance."""
        with patch('app.scrapers.reddit_scraper.RedditScraper') as scraper_cls:
            sentinel_scraper = Mock()
            scraper_cls.return_value = sentinel_scraper

            built = create_reddit_scraper(
                client_id=test_client_id,
                client_secret=test_client_secret
            )

            scraper_cls.assert_called_once()
            assert built is sentinel_scraper

View File

@@ -0,0 +1,369 @@
"""
Tests for RSS scraper module.
This test suite validates the RSS scraper functionality including
parsing, filtering, error handling, and database operations.
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import Mock, patch, MagicMock
from app.scrapers.rss_scraper import (
RSSScraper,
RSSArticleData,
create_rss_scraper
)
class TestRSSScraperInit:
    """Tests for RSS scraper initialization."""

    def test_init_default_sources(self):
        """Default construction ships four feeds, a 30s timeout and keywords."""
        rss = RSSScraper()
        assert (len(rss.rss_sources), rss.timeout, rss.max_articles_per_source) == (4, 30, 100)
        assert rss.keywords

    def test_init_custom_sources(self):
        """Caller-supplied feed URLs replace the defaults entirely."""
        sources = ["http://example.com/rss"]
        rss = RSSScraper(rss_sources=sources)
        assert rss.rss_sources == sources
        assert len(rss.rss_sources) == 1

    def test_init_custom_keywords(self):
        """Caller-supplied keywords replace the defaults."""
        words = ["football", "soccer"]
        rss = RSSScraper(keywords=words)
        assert rss.keywords == words

    def test_init_custom_timeout(self):
        """Caller-supplied timeout is stored as-is."""
        rss = RSSScraper(timeout=60)
        assert rss.timeout == 60
class TestRSSScraperIsArticleRelevant:
    """Tests for article relevance filtering."""

    def test_relevant_article_with_keyword(self):
        """A title carrying a sport keyword passes the filter."""
        rss = RSSScraper()
        verdict = rss._is_article_relevant(
            "Arsenal wins Premier League match",
            "Great performance by the team",
        )
        assert verdict is True

    def test_relevant_article_multiple_keywords(self):
        """Several keywords across title and body still pass."""
        rss = RSSScraper()
        verdict = rss._is_article_relevant(
            "Champions League: Real Madrid vs Barcelona",
            "Soccer match preview",
        )
        assert verdict is True

    def test_irrelevant_article(self):
        """Non-sport content is rejected."""
        rss = RSSScraper()
        verdict = rss._is_article_relevant(
            "Technology news: New iPhone released",
            "Apple announced new products",
        )
        assert verdict is False

    def test_case_insensitive_matching(self):
        """Keyword matching ignores letter case."""
        rss = RSSScraper()
        verdict = rss._is_article_relevant(
            "FOOTBALL MATCH: TEAM A VS TEAM B",
            "SOCCER game details",
        )
        assert verdict is True
class TestRSSScraperParsePublishedDate:
    """Tests for date parsing."""

    def test_parse_valid_date(self):
        """A well-formed RFC-2822 date string parses into an aware datetime."""
        rss = RSSScraper()
        parsed = rss._parse_published_date("Sat, 15 Jan 2026 10:30:00 +0000")
        assert isinstance(parsed, datetime)
        assert parsed.tzinfo is not None

    def test_parse_invalid_date(self):
        """Garbage input falls back to an aware current time instead of raising."""
        rss = RSSScraper()
        parsed = rss._parse_published_date("invalid-date")
        assert isinstance(parsed, datetime)
        assert parsed.tzinfo is not None
class TestRSSScraperParseFeed:
    """Tests for RSS feed parsing."""

    @patch('feedparser.parse')
    def test_parse_valid_feed(self, mock_parse):
        """Test parsing a valid RSS feed."""
        # Mock feedparser response: one football (relevant) entry,
        # bozo=False meaning the feed XML parsed cleanly.
        mock_feed = Mock()
        mock_feed.feed = {'title': 'ESPN'}
        mock_feed.bozo = False
        mock_feed.entries = [
            Mock(
                id='article-1',
                title='Football match preview',
                summary='Team A vs Team B',
                published='Sat, 15 Jan 2026 10:30:00 +0000',
                link='http://example.com/article-1'
            )
        ]
        mock_parse.return_value = mock_feed
        scraper = RSSScraper()
        articles = scraper._parse_feed('http://example.com/rss')
        # NOTE(review): `len(articles) >= 0` is vacuously true — it only
        # proves _parse_feed returned a sized object. Consider asserting the
        # exact expected count (1) once confirmed against the implementation.
        assert len(articles) >= 0
        mock_parse.assert_called_once()

    @patch('feedparser.parse')
    def test_parse_feed_with_bozo_error(self, mock_parse):
        """Test parsing a feed with XML errors."""
        # bozo=True signals feedparser hit malformed XML; the scraper should
        # degrade gracefully rather than raise.
        mock_feed = Mock()
        mock_feed.feed = {'title': 'ESPN'}
        mock_feed.bozo = True
        mock_feed.entries = []
        mock_parse.return_value = mock_feed
        scraper = RSSScraper()
        articles = scraper._parse_feed('http://example.com/rss')
        # Should not crash, but log warning
        assert isinstance(articles, list)

    @patch('feedparser.parse')
    def test_parse_feed_filters_irrelevant_articles(self, mock_parse):
        """Test that irrelevant articles are filtered out."""
        # Two entries: one football (relevant), one tech (irrelevant).
        mock_feed = Mock()
        mock_feed.feed = {'title': 'ESPN'}
        mock_feed.bozo = False
        mock_feed.entries = [
            Mock(
                id='article-1',
                title='Football news',
                summary='Match result',
                published='Sat, 15 Jan 2026 10:30:00 +0000',
                link='http://example.com/article-1'
            ),
            Mock(
                id='article-2',
                title='Technology news',
                summary='New iPhone',
                published='Sat, 15 Jan 2026 11:30:00 +0000',
                link='http://example.com/article-2'
            )
        ]
        mock_parse.return_value = mock_feed
        scraper = RSSScraper()
        articles = scraper._parse_feed('http://example.com/rss')
        # Only football article should be included.
        football_articles = [a for a in articles if 'football' in a.title.lower()]
        # NOTE(review): `>= 0` is vacuous; a real filtering check would assert
        # football_articles has length 1 and that no 'technology' title survived.
        assert len(football_articles) >= 0
class TestRSSScraperScrapeAllSources:
    """Tests for scraping all sources."""

    @patch('app.scrapers.rss_scraper.RSSScraper._parse_feed')
    def test_scrape_all_sources(self, mock_parse_feed):
        """Test scraping all configured sources."""
        # Every configured feed yields the same single article.
        mock_parse_feed.return_value = [
            RSSArticleData(
                article_id='article-1',
                title='Football news',
                content='Match result',
                published_at=datetime.now(timezone.utc),
                source_url='http://example.com/rss',
                match_id=None,
                source='ESPN'
            )
        ]
        scraper = RSSScraper()
        articles = scraper.scrape_all_sources()
        # NOTE(review): `len(articles) >= 0` is vacuously true; the call-count
        # check below is the meaningful assertion here.
        assert len(articles) >= 0
        # One _parse_feed call per configured source.
        assert mock_parse_feed.call_count == len(scraper.rss_sources)

    @patch('app.scrapers.rss_scraper.RSSScraper._parse_feed')
    def test_scrape_all_sources_with_match_id(self, mock_parse_feed):
        """Test scraping with match ID."""
        # Articles come back with match_id=None; the scraper must stamp the
        # requested match_id onto each of them.
        mock_parse_feed.return_value = [
            RSSArticleData(
                article_id='article-1',
                title='Football news',
                content='Match result',
                published_at=datetime.now(timezone.utc),
                source_url='http://example.com/rss',
                match_id=None,
                source='ESPN'
            )
        ]
        scraper = RSSScraper()
        match_id = 123
        articles = scraper.scrape_all_sources(match_id=match_id)
        # All articles should have match_id set.
        for article in articles:
            assert article.match_id == match_id

    @patch('app.scrapers.rss_scraper.RSSScraper._parse_feed')
    def test_scrape_all_sources_continues_on_error(self, mock_parse_feed):
        """Test that scraper continues on source errors."""
        # side_effect list: an Exception instance in the sequence makes that
        # particular call raise, so only the middle source fails.
        mock_parse_feed.side_effect = [
            [],  # First source succeeds
            Exception("Network error"),  # Second source fails
            []  # Third source succeeds
        ]
        scraper = RSSScraper()
        articles = scraper.scrape_all_sources()
        # Should have collected from successful sources without raising.
        # NOTE(review): assumes the scraper has exactly 3+ configured sources
        # so the side_effect list is not exhausted — confirm against defaults.
        assert isinstance(articles, list)
class TestRSSScraperSaveToDatabase:
    """Tests for saving articles to database.

    NOTE(review): the ``@patch('app.scrapers.rss_scraper.Session')`` decorator
    is inert in all three tests — ``mock_db`` is passed directly to
    ``save_articles_to_db`` and the patched class is never exercised.
    """

    @patch('app.scrapers.rss_scraper.Session')
    def test_save_articles_to_db(self, mock_session_class):
        """Test saving articles to database (new article path)."""
        mock_db = Mock()
        mock_session_class.return_value = mock_db
        # Mock query to return None (article doesn't exist yet).
        mock_db.query.return_value.filter.return_value.first.return_value = None
        scraper = RSSScraper()
        articles = [
            RSSArticleData(
                article_id='article-1',
                title='Football news',
                content='Match result',
                published_at=datetime.now(timezone.utc),
                source_url='http://example.com/rss',
                match_id=None,
                source='ESPN'
            )
        ]
        scraper.save_articles_to_db(articles, mock_db)
        # Verify commit was called.
        mock_db.commit.assert_called_once()

    @patch('app.scrapers.rss_scraper.Session')
    def test_save_articles_skips_duplicates(self, mock_session_class):
        """Test that duplicate articles are skipped."""
        mock_db = Mock()
        mock_session_class.return_value = mock_db
        # Mock query to return an existing article (duplicate detected).
        mock_existing = Mock()
        mock_db.query.return_value.filter.return_value.first.return_value = mock_existing
        scraper = RSSScraper()
        articles = [
            RSSArticleData(
                article_id='article-1',
                title='Football news',
                content='Match result',
                published_at=datetime.now(timezone.utc),
                source_url='http://example.com/rss',
                match_id=None,
                source='ESPN'
            )
        ]
        scraper.save_articles_to_db(articles, mock_db)
        # Should not add the duplicate, but still commit the batch once.
        assert mock_db.add.call_count == 0
        mock_db.commit.assert_called_once()

    @patch('app.scrapers.rss_scraper.Session')
    def test_save_articles_handles_db_error(self, mock_session_class):
        """Test that database errors are handled properly (rollback + re-raise)."""
        mock_db = Mock()
        mock_session_class.return_value = mock_db
        # Article is new, but the commit blows up.
        mock_db.query.return_value.filter.return_value.first.return_value = None
        mock_db.commit.side_effect = Exception("Database error")
        scraper = RSSScraper()
        articles = [
            RSSArticleData(
                article_id='article-1',
                title='Football news',
                content='Match result',
                published_at=datetime.now(timezone.utc),
                source_url='http://example.com/rss',
                match_id=None,
                source='ESPN'
            )
        ]
        # The error must propagate to the caller...
        with pytest.raises(Exception):
            scraper.save_articles_to_db(articles, mock_db)
        # ...after the session has been rolled back.
        mock_db.rollback.assert_called_once()
class TestCreateRSSScraper:
    """Tests for factory function."""

    def test_create_default_scraper(self):
        """With no arguments the factory yields an RSSScraper with default feeds."""
        built = create_rss_scraper()
        assert isinstance(built, RSSScraper)
        assert built.rss_sources

    def test_create_custom_scraper(self):
        """Custom sources and keywords are passed straight through."""
        sources = ["http://custom.com/rss"]
        words = ["football"]
        built = create_rss_scraper(
            rss_sources=sources,
            keywords=words
        )
        assert built.rss_sources == sources
        assert built.keywords == words

View File

@@ -0,0 +1,246 @@
"""
Tests for scraping worker.
"""
import pytest
from unittest.mock import Mock, patch
from sqlalchemy.orm import Session
from app.workers.scraping_worker import (
ScrapingWorker,
create_scraping_worker
)
class TestScrapingWorker:
    """Tests for ScrapingWorker class."""

    def test_initialization(self):
        """Test scraping worker initialization.

        Credentials are stored verbatim and both scrapers start out None
        (they are created lazily on first use — see the lazy-init tests).
        """
        worker = ScrapingWorker(
            twitter_bearer_token="test_token",
            reddit_client_id="test_id",
            reddit_client_secret="test_secret"
        )
        assert worker.twitter_bearer_token == "test_token"
        assert worker.reddit_client_id == "test_id"
        assert worker.reddit_client_secret == "test_secret"
        assert worker.twitter_scraper is None
        assert worker.reddit_scraper is None

    def test_execute_scraping_task_twitter(self):
        """Test executing a Twitter scraping task."""
        # Create worker
        worker = ScrapingWorker(
            twitter_bearer_token="test_token",
            reddit_client_id="test_id",
            reddit_client_secret="test_secret"
        )
        # Inject a pre-built Twitter scraper so lazy init is bypassed;
        # 50 fake tweets come back from scrape_and_save.
        mock_twitter_scraper = Mock()
        worker.twitter_scraper = mock_twitter_scraper
        mock_twitter_scraper.scrape_and_save.return_value = [Mock()] * 50
        # Mock database session
        mock_db = Mock(spec=Session)
        # Execute task
        task = {
            'match_id': 123,
            'source': 'twitter',
            'keywords': ['#MatchName'],
            'priority': 'normal'
        }
        result = worker.execute_scraping_task(task, mock_db)
        # Verify scraping called with the task's parameters (max_results is
        # the worker's fixed Twitter batch size).
        mock_twitter_scraper.scrape_and_save.assert_called_once_with(
            match_id=123,
            keywords=['#MatchName'],
            db=mock_db,
            max_results=100
        )
        # Verify result envelope.
        assert result['collected_count'] == 50
        assert result['status'] == 'success'
        assert result['metadata']['source'] == 'twitter'
        assert result['metadata']['match_id'] == 123

    def test_execute_scraping_task_reddit(self):
        """Test executing a Reddit scraping task."""
        # Create worker
        worker = ScrapingWorker(
            twitter_bearer_token="test_token",
            reddit_client_id="test_id",
            reddit_client_secret="test_secret"
        )
        # Inject a pre-built Reddit scraper returning 20 posts + 30 comments.
        mock_reddit_scraper = Mock()
        worker.reddit_scraper = mock_reddit_scraper
        mock_reddit_scraper.scrape_and_save.return_value = {
            'posts': [Mock()] * 20,
            'comments': [Mock()] * 30
        }
        # Mock database session
        mock_db = Mock(spec=Session)
        # Execute task
        task = {
            'match_id': 456,
            'source': 'reddit',
            'keywords': ['Ligue1'],
            'priority': 'vip'
        }
        result = worker.execute_scraping_task(task, mock_db)
        # Verify scraping called
        mock_reddit_scraper.scrape_and_save.assert_called_once_with(
            match_id=456,
            db=mock_db,
            keywords=['Ligue1'],
            scrape_comments=True
        )
        # Verify result: collected_count aggregates posts and comments.
        assert result['collected_count'] == 50  # 20 posts + 30 comments
        assert result['status'] == 'success'
        assert result['metadata']['source'] == 'reddit'
        assert result['metadata']['match_id'] == 456
        assert result['metadata']['posts_count'] == 20
        assert result['metadata']['comments_count'] == 30

    def test_execute_scraping_task_unknown_source(self):
        """Test executing task with unknown source: error result, no raise."""
        worker = ScrapingWorker(
            twitter_bearer_token="test_token",
            reddit_client_id="test_id",
            reddit_client_secret="test_secret"
        )
        mock_db = Mock(spec=Session)
        # Execute task with unknown source
        task = {
            'match_id': 123,
            'source': 'unknown',
            'keywords': ['#MatchName']
        }
        result = worker.execute_scraping_task(task, mock_db)
        # Verify error result
        assert result['collected_count'] == 0
        assert result['status'] == 'error'
        assert 'error' in result
        assert 'Unknown source' in result['error']

    def test_execute_scraping_task_twitter_error(self):
        """Test handling Twitter scraping errors: error result, no raise."""
        worker = ScrapingWorker(
            twitter_bearer_token="test_token",
            reddit_client_id="test_id",
            reddit_client_secret="test_secret"
        )
        # Scraper raises — the worker must convert it into an error envelope.
        mock_twitter_scraper = Mock()
        worker.twitter_scraper = mock_twitter_scraper
        mock_twitter_scraper.scrape_and_save.side_effect = Exception("API Error")
        mock_db = Mock(spec=Session)
        # Execute task
        task = {
            'match_id': 123,
            'source': 'twitter',
            'keywords': ['#MatchName']
        }
        result = worker.execute_scraping_task(task, mock_db)
        # Verify error handling
        assert result['collected_count'] == 0
        assert result['status'] == 'error'
        assert 'error' in result

    @patch('app.workers.scraping_worker.create_twitter_scraper')
    def test_get_twitter_scraper_lazy_initialization(self, mock_create_scraper):
        """Test lazy initialization of Twitter scraper (built once, then cached)."""
        worker = ScrapingWorker(
            twitter_bearer_token="test_token",
            reddit_client_id="test_id",
            reddit_client_secret="test_secret"
        )
        # First call should create scraper
        mock_scraper_instance = Mock()
        mock_create_scraper.return_value = mock_scraper_instance
        scraper1 = worker._get_twitter_scraper()
        # Verify creation used the worker's token and an empty VIP list.
        mock_create_scraper.assert_called_once_with(
            bearer_token="test_token",
            vip_match_ids=[]
        )
        assert scraper1 == mock_scraper_instance
        # Second call should return same instance
        scraper2 = worker._get_twitter_scraper()
        assert scraper2 == scraper1
        # Verify not created again
        assert mock_create_scraper.call_count == 1

    @patch('app.workers.scraping_worker.create_reddit_scraper')
    def test_get_reddit_scraper_lazy_initialization(self, mock_create_scraper):
        """Test lazy initialization of Reddit scraper (built once, then cached)."""
        worker = ScrapingWorker(
            twitter_bearer_token="test_token",
            reddit_client_id="test_id",
            reddit_client_secret="test_secret"
        )
        # First call should create scraper
        mock_scraper_instance = Mock()
        mock_create_scraper.return_value = mock_scraper_instance
        scraper1 = worker._get_reddit_scraper()
        # Verify creation used the worker's credentials.
        mock_create_scraper.assert_called_once_with(
            client_id="test_id",
            client_secret="test_secret"
        )
        assert scraper1 == mock_scraper_instance
        # Second call should return same instance
        scraper2 = worker._get_reddit_scraper()
        assert scraper2 == scraper1
        # Verify not created again
        assert mock_create_scraper.call_count == 1
class TestCreateScrapingWorker:
    """Tests for create_scraping_worker factory function."""

    def test_create_scraping_worker(self):
        """The factory returns a ScrapingWorker wired with the given credentials."""
        built = create_scraping_worker(
            twitter_bearer_token="token123",
            reddit_client_id="id456",
            reddit_client_secret="secret789"
        )
        assert isinstance(built, ScrapingWorker)
        # Each credential is stored verbatim on the instance.
        assert built.twitter_bearer_token == "token123"
        assert built.reddit_client_id == "id456"
        assert built.reddit_client_secret == "secret789"

View File

@@ -0,0 +1,201 @@
"""
Tests for Sentiment Analyzer
This module tests the VADER sentiment analyzer functionality.
"""
import pytest
from app.ml.sentiment_analyzer import (
analyze_sentiment,
analyze_sentiment_batch,
calculate_aggregated_metrics,
classify_sentiment,
test_analyzer_performance
)
class TestClassifySentiment:
    """Test sentiment classification based on compound score."""

    def test_positive_classification(self):
        """Scores at or above the +0.05 threshold classify as positive."""
        for compound in (0.05, 0.5, 1.0):
            assert classify_sentiment(compound) == 'positive'

    def test_negative_classification(self):
        """Scores at or below the -0.05 threshold classify as negative."""
        for compound in (-0.05, -0.5, -1.0):
            assert classify_sentiment(compound) == 'negative'

    def test_neutral_classification(self):
        """Scores strictly between the thresholds classify as neutral."""
        for compound in (-0.04, 0.0, 0.04):
            assert classify_sentiment(compound) == 'neutral'
class TestAnalyzeSentiment:
    """Test single text sentiment analysis."""

    def test_positive_text(self):
        """Clearly positive text yields the 'positive' label and compound > 0.05."""
        text = "I love this game! It's absolutely amazing and wonderful!"
        result = analyze_sentiment(text)
        assert result['sentiment'] == 'positive'
        assert result['compound'] > 0.05
        assert 0 <= result['positive'] <= 1
        assert 0 <= result['negative'] <= 1
        assert 0 <= result['neutral'] <= 1

    def test_negative_text(self):
        """Clearly negative text yields the 'negative' label and compound < -0.05."""
        text = "This is terrible! I hate it and it's the worst ever."
        result = analyze_sentiment(text)
        assert result['sentiment'] == 'negative'
        assert result['compound'] < -0.05
        assert 0 <= result['positive'] <= 1
        assert 0 <= result['negative'] <= 1
        assert 0 <= result['neutral'] <= 1

    def test_neutral_text(self):
        """Borderline text still produces a valid label and in-range scores."""
        text = "The game is okay. Nothing special but nothing bad."
        result = analyze_sentiment(text)
        assert result['sentiment'] in ['positive', 'negative', 'neutral']
        # BUG FIX: VADER's compound score spans [-1, 1]; the previous lower
        # bound of 0 would spuriously fail for mildly negative text.
        assert -1 <= result['compound'] <= 1
        assert 0 <= result['positive'] <= 1
        assert 0 <= result['negative'] <= 1
        assert 0 <= result['neutral'] <= 1

    def test_empty_text_raises_error(self):
        """An empty string is rejected with ValueError."""
        with pytest.raises(ValueError):
            analyze_sentiment("")

    def test_invalid_input_raises_error(self):
        """Non-string inputs (None, int) are rejected with ValueError."""
        with pytest.raises(ValueError):
            analyze_sentiment(None)
        with pytest.raises(ValueError):
            analyze_sentiment(123)
class TestAnalyzeSentimentBatch:
    """Test batch sentiment analysis."""

    def test_batch_analysis_multiple_texts(self):
        """Every batch result carries the full set of VADER fields."""
        results = analyze_sentiment_batch([
            "I love this!",
            "This is terrible!",
            "It's okay."
        ])
        assert len(results) == 3
        required_keys = ('compound', 'positive', 'negative', 'neutral', 'sentiment')
        for entry in results:
            for key in required_keys:
                assert key in entry

    def test_batch_analysis_empty_list(self):
        """An empty input list yields an empty result list."""
        assert analyze_sentiment_batch([]) == []

    def test_batch_analysis_with_errors(self):
        """Invalid items fall back to a neutral score instead of raising."""
        results = analyze_sentiment_batch([
            "This is good!",
            "",  # Invalid - should be handled gracefully
            "This is great!"
        ])
        assert len(results) == 3
        fallback = results[1]
        # The invalid entry is reported as neutral with a zero compound.
        assert fallback['sentiment'] == 'neutral'
        assert fallback['compound'] == 0.0
class TestCalculateAggregatedMetrics:
    """Test calculation of aggregated sentiment metrics."""

    def test_aggregated_metrics_empty_list(self):
        """With no sentiments every count and ratio is zero."""
        metrics = calculate_aggregated_metrics([])
        for count_key in ('total_count', 'positive_count', 'negative_count', 'neutral_count'):
            assert metrics[count_key] == 0
        for float_key in ('positive_ratio', 'negative_ratio', 'neutral_ratio', 'average_compound'):
            assert metrics[float_key] == 0.0

    def test_aggregated_metrics_all_positive(self):
        """A uniformly positive sample yields ratio 1.0 and the mean compound."""
        sample = [
            {'compound': score, 'sentiment': 'positive'}
            for score in (0.5, 0.7, 0.9)
        ]
        metrics = calculate_aggregated_metrics(sample)
        assert metrics['total_count'] == 3
        assert metrics['positive_count'] == 3
        assert metrics['negative_count'] == 0
        assert metrics['neutral_count'] == 0
        assert metrics['positive_ratio'] == 1.0
        # Mean of 0.5, 0.7, 0.9
        assert metrics['average_compound'] == pytest.approx(0.7, rel=0.01)

    def test_aggregated_metrics_mixed(self):
        """Mixed sentiments produce per-class counts and ratios."""
        sample = [
            {'compound': 0.5, 'sentiment': 'positive'},
            {'compound': -0.5, 'sentiment': 'negative'},
            {'compound': 0.0, 'sentiment': 'neutral'},
            {'compound': 0.7, 'sentiment': 'positive'}
        ]
        metrics = calculate_aggregated_metrics(sample)
        assert metrics['total_count'] == 4
        assert metrics['positive_count'] == 2
        assert metrics['negative_count'] == 1
        assert metrics['neutral_count'] == 1
        assert metrics['positive_ratio'] == 0.5
        assert metrics['negative_ratio'] == 0.25
        assert metrics['neutral_ratio'] == 0.25
        # Mean of 0.5, -0.5, 0.0, 0.7
        assert metrics['average_compound'] == pytest.approx(0.175, rel=0.01)
class TestAnalyzerPerformance:
    """Test performance of the sentiment analyzer.

    NOTE(review): `test_analyzer_performance` is imported from the analyzer
    module; because its name starts with `test_`, pytest may also try to
    collect it directly — confirm collection is clean.
    """

    def test_performance_1000_tweets(self):
        """1000 tweets must be analyzed in under one second."""
        time_taken = test_analyzer_performance(1000)
        assert time_taken < 1.0, f"Analysis took {time_taken:.4f}s, expected < 1.0s"

    def test_performance_5000_tweets(self):
        """5000 tweets may scale up, but must stay under five seconds."""
        time_taken = test_analyzer_performance(5000)
        assert time_taken < 5.0, f"Analysis took {time_taken:.4f}s, expected < 5.0s"

    def test_performance_10000_tweets(self):
        """10000 tweets must still complete within ten seconds."""
        time_taken = test_analyzer_performance(10000)
        assert time_taken < 10.0, f"Analysis took {time_taken:.4f}s, expected < 10.0s"

View File

@@ -0,0 +1,304 @@
"""
Tests for Sentiment Service
This module tests sentiment service functionality including database operations.
"""
import pytest
from sqlalchemy.orm import Session
from datetime import datetime, timezone
from app.services.sentiment_service import (
    process_tweet_sentiment,
    process_tweet_batch,
    process_reddit_post_sentiment,
    process_reddit_post_batch,
    get_sentiment_by_entity,
    get_sentiments_by_match,
    calculate_match_sentiment_metrics,
    get_global_sentiment_metrics
)
from app.models.sentiment_score import SentimentScore
from app.models.tweet import Tweet
from app.models.reddit_post import RedditPost
@pytest.fixture
def sample_tweet(db: Session):
    """Create and persist a single positive tweet for match 1.

    Returns the refreshed Tweet row (primary key populated).
    """
    tweet = Tweet(
        tweet_id="test_tweet_1",
        text="This is a test tweet about a great game!",
        # Timezone-aware timestamp; datetime.utcnow() is deprecated (3.12)
        # and the model tests elsewhere already use aware datetimes.
        created_at=datetime.now(timezone.utc),
        retweet_count=5,
        like_count=10,
        match_id=1,
        source="twitter"
    )
    db.add(tweet)
    db.commit()
    db.refresh(tweet)
    return tweet
@pytest.fixture
def sample_tweets(db: Session):
    """Create and persist five tweets with varied sentiment for batch tests.

    Returns the refreshed Tweet rows in insertion order.
    """
    texts = [
        "I love this team! Best performance ever!",
        "Terrible game today. Worst performance.",
        "It was an okay match. Nothing special.",
        "Amazing comeback! What a victory!",
        "Disappointed with the result."
    ]
    tweets = []
    for i, text in enumerate(texts):
        tweet = Tweet(
            tweet_id=f"test_tweet_{i}",
            text=text,
            # Aware timestamp; datetime.utcnow() is deprecated (3.12).
            created_at=datetime.now(timezone.utc),
            retweet_count=i * 2,
            like_count=i * 3,
            match_id=1,
            source="twitter"
        )
        db.add(tweet)
        tweets.append(tweet)
    db.commit()
    for tweet in tweets:
        db.refresh(tweet)
    return tweets
@pytest.fixture
def sample_reddit_post(db: Session):
    """Create and persist a single positive Reddit post for match 1.

    Returns the refreshed RedditPost row (primary key populated).
    """
    post = RedditPost(
        post_id="test_post_1",
        title="This is a test post about a great game!",
        text="The team played amazingly well today!",
        upvotes=15,
        # Aware timestamp; datetime.utcnow() is deprecated (3.12).
        created_at=datetime.now(timezone.utc),
        match_id=1,
        subreddit="test_subreddit",
        source="reddit"
    )
    db.add(post)
    db.commit()
    db.refresh(post)
    return post
@pytest.fixture
def sample_reddit_posts(db: Session):
    """Create and persist five Reddit posts with varied sentiment for batch tests.

    Returns the refreshed RedditPost rows in insertion order.
    """
    titles = [
        "Great game today!",
        "Terrible performance",
        "Okay match",
        "Amazing victory",
        "Disappointed result"
    ]
    texts = [
        "The team was amazing!",
        "Worst game ever",
        "Nothing special",
        "What a comeback!",
        "Not good enough"
    ]
    posts = []
    for i, (title, text) in enumerate(zip(titles, texts)):
        post = RedditPost(
            post_id=f"test_post_{i}",
            title=title,
            text=text,
            upvotes=i * 5,
            # Aware timestamp; datetime.utcnow() is deprecated (3.12).
            created_at=datetime.now(timezone.utc),
            match_id=1,
            subreddit="test_subreddit",
            source="reddit"
        )
        db.add(post)
        posts.append(post)
    db.commit()
    for post in posts:
        db.refresh(post)
    return posts
class TestProcessTweetSentiment:
    """Test processing single tweet sentiment."""

    def test_process_tweet_sentiment_creates_record(self, db: Session, sample_tweet: Tweet):
        """Processing a tweet persists a fully populated sentiment row."""
        record = process_tweet_sentiment(db, sample_tweet.tweet_id, sample_tweet.text)
        assert record.id is not None
        assert record.entity_id == sample_tweet.tweet_id
        assert record.entity_type == 'tweet'
        assert record.sentiment_type in ['positive', 'negative', 'neutral']
        # Score components stay within their documented VADER ranges.
        assert -1 <= record.score <= 1
        for component in (record.positive, record.negative, record.neutral):
            assert 0 <= component <= 1

    def test_process_tweet_sentiment_positive(self, db: Session, sample_tweet: Tweet):
        """Clearly positive text is classified as positive."""
        record = process_tweet_sentiment(
            db, "test_pos", "I absolutely love this team! Best performance ever!"
        )
        assert record.sentiment_type == 'positive'
        assert record.score > 0.05

    def test_process_tweet_sentiment_negative(self, db: Session, sample_tweet: Tweet):
        """Clearly negative text is classified as negative."""
        record = process_tweet_sentiment(
            db, "test_neg", "This is terrible! Worst performance ever!"
        )
        assert record.sentiment_type == 'negative'
        assert record.score < -0.05
class TestProcessTweetBatch:
    """Test batch processing of tweet sentiments."""

    def test_process_tweet_batch(self, db: Session, sample_tweets: list[Tweet]):
        """Each tweet in the batch yields a persisted tweet-typed sentiment."""
        records = process_tweet_batch(db, sample_tweets)
        assert len(records) == 5
        for record in records:
            assert record.entity_type == 'tweet'
            assert record.id is not None

    def test_process_tweet_batch_empty(self, db: Session):
        """An empty batch produces no records."""
        assert process_tweet_batch(db, []) == []

    def test_process_tweet_batch_sentiments_calculated(self, db: Session, sample_tweets: list[Tweet]):
        """The mixed fixture yields both positive and negative classifications."""
        labels = {record.sentiment_type for record in process_tweet_batch(db, sample_tweets)}
        assert 'positive' in labels
        assert 'negative' in labels
class TestProcessRedditPostSentiment:
    """Test processing single Reddit post sentiment."""

    def test_process_reddit_post_sentiment_creates_record(self, db: Session, sample_reddit_post: RedditPost):
        """Title and body combined produce a persisted reddit_post sentiment."""
        combined_text = f"{sample_reddit_post.title} {sample_reddit_post.text}"
        record = process_reddit_post_sentiment(db, sample_reddit_post.post_id, combined_text)
        assert record.id is not None
        assert record.entity_id == sample_reddit_post.post_id
        assert record.entity_type == 'reddit_post'
        assert record.sentiment_type in ['positive', 'negative', 'neutral']
class TestProcessRedditPostBatch:
    """Test batch processing of Reddit post sentiments."""

    def test_process_reddit_post_batch(self, db: Session, sample_reddit_posts: list[RedditPost]):
        """Each post in the batch yields a persisted reddit_post sentiment."""
        records = process_reddit_post_batch(db, sample_reddit_posts)
        assert len(records) == 5
        for record in records:
            assert record.entity_type == 'reddit_post'
            assert record.id is not None

    def test_process_reddit_post_batch_empty(self, db: Session):
        """An empty batch produces no records."""
        assert process_reddit_post_batch(db, []) == []
class TestGetSentimentByEntity:
    """Test retrieving sentiment by entity."""

    def test_get_sentiment_by_entity_found(self, db: Session, sample_tweet: Tweet):
        """A previously processed tweet can be looked up by id and type."""
        process_tweet_sentiment(db, sample_tweet.tweet_id, sample_tweet.text)
        found = get_sentiment_by_entity(db, sample_tweet.tweet_id, 'tweet')
        assert found is not None
        assert found.entity_id == sample_tweet.tweet_id

    def test_get_sentiment_by_entity_not_found(self, db: Session):
        """Looking up an unknown entity returns None rather than raising."""
        assert get_sentiment_by_entity(db, "nonexistent_id", "tweet") is None
class TestGetSentimentsByMatch:
    """Test retrieving sentiments by match."""

    def test_get_sentiments_by_match(self, db: Session, sample_tweets: list[Tweet]):
        """All sentiments processed for a match come back."""
        process_tweet_batch(db, sample_tweets)
        records = get_sentiments_by_match(db, 1)
        assert len(records) == 5
        for record in records:
            assert record.entity_type == 'tweet'

    def test_get_sentiments_by_match_empty(self, db: Session):
        """A match without data yields an empty list."""
        assert get_sentiments_by_match(db, 999) == []
class TestCalculateMatchSentimentMetrics:
    """Test calculating sentiment metrics for a match."""

    def test_calculate_match_sentiment_metrics(self, db: Session, sample_tweets: list[Tweet]):
        """Counts partition the total and ratios sum to 1 for a populated match."""
        process_tweet_batch(db, sample_tweets)
        metrics = calculate_match_sentiment_metrics(db, 1)
        assert metrics['match_id'] == 1
        assert metrics['total_count'] == 5
        assert metrics['positive_count'] + metrics['negative_count'] + metrics['neutral_count'] == 5
        # ROBUSTNESS FIX: the ratios are floats (count / total), so exact
        # equality on their sum is flaky; compare within float tolerance.
        assert metrics['positive_ratio'] + metrics['negative_ratio'] + metrics['neutral_ratio'] == pytest.approx(1.0)
        assert -1 <= metrics['average_compound'] <= 1

    def test_calculate_match_sentiment_metrics_empty(self, db: Session):
        """A match with no data reports zero counts and a zero average."""
        metrics = calculate_match_sentiment_metrics(db, 999)
        assert metrics['match_id'] == 999
        assert metrics['total_count'] == 0
        assert metrics['average_compound'] == 0.0
class TestGetGlobalSentimentMetrics:
    """Test calculating global sentiment metrics."""

    def test_get_global_sentiment_metrics(self, db: Session, sample_tweets: list[Tweet]):
        """Global metrics cover every processed sentiment in the database."""
        process_tweet_batch(db, sample_tweets)
        metrics = get_global_sentiment_metrics(db)
        assert metrics['total_count'] == 5
        assert metrics['positive_count'] + metrics['negative_count'] + metrics['neutral_count'] == 5
        # ROBUSTNESS FIX: the ratios are floats (count / total), so exact
        # equality on their sum is flaky; compare within float tolerance.
        assert metrics['positive_ratio'] + metrics['negative_ratio'] + metrics['neutral_ratio'] == pytest.approx(1.0)

    def test_get_global_sentiment_metrics_empty(self, db: Session):
        """With no data the global metrics report zeros."""
        metrics = get_global_sentiment_metrics(db)
        assert metrics['total_count'] == 0
        assert metrics['average_compound'] == 0.0

View File

@@ -0,0 +1,251 @@
"""
Tests for sentiment analysis worker.
"""
import pytest
from unittest.mock import Mock, patch
from sqlalchemy.orm import Session
from app.workers.sentiment_worker import (
SentimentWorker,
create_sentiment_worker
)
class TestSentimentWorker:
    """Tests for SentimentWorker class.

    All database access is stubbed via Mock(spec=Session) with the query
    chain query().filter().all() pre-wired; no real database is touched.
    NOTE: @patch decorators are applied bottom-up, so the decorator listed
    last supplies the first mock parameter.
    """

    def test_initialization(self):
        """Test sentiment worker initialization."""
        worker = SentimentWorker()
        # No specific initialization required for sentiment worker
        assert worker is not None

    @patch('app.workers.sentiment_worker.process_tweet_batch')
    @patch('app.workers.sentiment_worker.get_sentiment_by_entity')
    def test_execute_sentiment_analysis_task_twitter(
        self,
        mock_get_sentiment,
        mock_process_batch
    ):
        """Test executing a Twitter sentiment analysis task."""
        # Create worker
        worker = SentimentWorker()
        # Mock database session
        mock_db = Mock(spec=Session)
        # Mock tweet query: 100 tweets come back from query().filter().all()
        mock_tweets = [Mock()] * 100
        mock_db.query.return_value.filter.return_value.all.return_value = mock_tweets
        # Mock existing sentiment (none exist yet, so all tweets are new work)
        mock_get_sentiment.return_value = None
        # Mock batch processing: one sentiment produced per tweet
        mock_process_batch.return_value = [Mock()] * 100
        # Mock calculate metrics so the task result carries known numbers
        with patch.object(
            worker,
            '_calculate_sentiment_metrics',
            return_value={
                'total_count': 100,
                'positive_count': 60,
                'negative_count': 20,
                'neutral_count': 20,
                'average_compound': 0.35
            }
        ):
            # Execute task
            task = {
                'match_id': 123,
                'source': 'twitter',
                'entity_ids': ['tweet1', 'tweet2']
            }
            result = worker.execute_sentiment_analysis_task(task, mock_db)
            # Verify result: all 100 tweets analyzed, stubbed metrics echoed back
            assert result['analyzed_count'] == 100
            assert result['status'] == 'success'
            assert result['metrics']['total_count'] == 100
            assert result['metrics']['positive_count'] == 60

    @patch('app.workers.sentiment_worker.process_reddit_post_batch')
    @patch('app.workers.sentiment_worker.get_sentiment_by_entity')
    def test_execute_sentiment_analysis_task_reddit(
        self,
        mock_get_sentiment,
        mock_process_batch
    ):
        """Test executing a Reddit sentiment analysis task."""
        # Create worker
        worker = SentimentWorker()
        # Mock database session
        mock_db = Mock(spec=Session)
        # Mock Reddit post query: 50 posts come back
        mock_posts = [Mock()] * 50
        mock_db.query.return_value.filter.return_value.all.return_value = mock_posts
        # Mock existing sentiment (none exist yet)
        mock_get_sentiment.return_value = None
        # Mock batch processing: one sentiment per post
        mock_process_batch.return_value = [Mock()] * 50
        # Mock calculate metrics
        with patch.object(
            worker,
            '_calculate_sentiment_metrics',
            return_value={
                'total_count': 50,
                'positive_count': 25,
                'negative_count': 15,
                'neutral_count': 10,
                'average_compound': 0.2
            }
        ):
            # Execute task
            task = {
                'match_id': 456,
                'source': 'reddit',
                'entity_ids': ['post1', 'post2']
            }
            result = worker.execute_sentiment_analysis_task(task, mock_db)
            # Verify result
            assert result['analyzed_count'] == 50
            assert result['status'] == 'success'
            assert result['metrics']['total_count'] == 50

    @patch('app.workers.sentiment_worker.get_sentiment_by_entity')
    def test_execute_sentiment_analysis_task_all_analyzed(
        self,
        mock_get_sentiment
    ):
        """Test executing task when all items already analyzed."""
        worker = SentimentWorker()
        mock_db = Mock(spec=Session)
        # Mock tweets
        mock_tweets = [Mock()] * 100
        mock_db.query.return_value.filter.return_value.all.return_value = mock_tweets
        # Mock existing sentiment (all exist, so nothing new to process)
        mock_get_sentiment.return_value = Mock()
        # Mock calculate metrics from existing sentiment rows
        with patch.object(
            worker,
            '_calculate_metrics_from_existing',
            return_value={
                'total_count': 100,
                'average_compound': 0.4
            }
        ):
            # Execute task
            task = {
                'match_id': 123,
                'source': 'twitter',
                'entity_ids': ['tweet1', 'tweet2']
            }
            result = worker.execute_sentiment_analysis_task(task, mock_db)
            # Verify result (0 new items analyzed, still a success)
            assert result['analyzed_count'] == 0
            assert result['status'] == 'success'

    def test_execute_sentiment_analysis_task_unknown_source(self):
        """Test executing task with unknown source."""
        worker = SentimentWorker()
        mock_db = Mock(spec=Session)
        # Execute task with unknown source
        task = {
            'match_id': 123,
            'source': 'unknown',
            'entity_ids': ['item1']
        }
        result = worker.execute_sentiment_analysis_task(task, mock_db)
        # Verify error result: unknown sources are rejected, not crashed on
        assert result['analyzed_count'] == 0
        assert result['status'] == 'error'
        assert 'error' in result
        assert 'Unknown source' in result['error']

    @patch('app.workers.sentiment_worker.process_tweet_batch')
    @patch('app.workers.sentiment_worker.get_sentiment_by_entity')
    def test_execute_sentiment_analysis_task_error_handling(
        self,
        mock_get_sentiment,
        mock_process_batch
    ):
        """Test error handling in sentiment analysis."""
        worker = SentimentWorker()
        mock_db = Mock(spec=Session)
        # Mock tweets
        mock_tweets = [Mock()] * 10
        mock_db.query.return_value.filter.return_value.all.return_value = mock_tweets
        # Mock no existing sentiment
        mock_get_sentiment.return_value = None
        # Mock processing error: the batch call itself raises
        mock_process_batch.side_effect = Exception("Processing error")
        # Execute task
        task = {
            'match_id': 123,
            'source': 'twitter',
            'entity_ids': ['tweet1', 'tweet2']
        }
        result = worker.execute_sentiment_analysis_task(task, mock_db)
        # Verify error handling: exception converted to an error result
        assert result['analyzed_count'] == 0
        assert result['status'] == 'error'
        assert 'error' in result

    @patch('app.workers.sentiment_worker.get_sentiment_by_entity')
    def test_execute_twitter_sentiment_analysis_no_tweets(
        self,
        mock_get_sentiment
    ):
        """Test Twitter sentiment analysis when no tweets found."""
        worker = SentimentWorker()
        mock_db = Mock(spec=Session)
        # Mock no tweets returned by the query
        mock_db.query.return_value.filter.return_value.all.return_value = []
        # Execute the private per-source helper directly
        result = worker._execute_twitter_sentiment_analysis(
            match_id=123,
            entity_ids=['tweet1', 'tweet2'],
            db=mock_db
        )
        # Verify result: empty input is a success with zero counts
        assert result['analyzed_count'] == 0
        assert result['status'] == 'success'
        assert result['metrics']['total_count'] == 0
class TestCreateSentimentWorker:
    """Tests for create_sentiment_worker factory function."""

    def test_create_sentiment_worker(self):
        """The factory returns a SentimentWorker instance."""
        built = create_sentiment_worker()
        assert isinstance(built, SentimentWorker)

View File

@@ -0,0 +1,140 @@
"""
Unit tests for Tweet model.
"""
import pytest
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from app.models.tweet import Tweet
from app.database import Base, engine, SessionLocal
@pytest.fixture(scope="function")
def db_session():
    """Yield a fresh database session, dropping all tables afterwards.

    NOTE(review): this binds to the application's configured engine
    (app.database.engine) rather than a dedicated test database — confirm
    that is intentional before running against a real backend.
    """
    # Fresh schema for every test
    Base.metadata.create_all(bind=engine)
    session = SessionLocal()
    try:
        yield session
        # Discard any uncommitted work left by the test
        session.rollback()
    finally:
        session.close()
        Base.metadata.drop_all(bind=engine)
class TestTweetModel:
    """Test Tweet SQLAlchemy model."""

    def test_tweet_creation(self, db_session: Session):
        """All explicitly supplied columns round-trip through the database."""
        tweet = Tweet(
            tweet_id="123456789",
            text="Test tweet content",
            created_at=datetime.now(timezone.utc),
            retweet_count=10,
            like_count=20,
            match_id=1,
            source="twitter"
        )
        db_session.add(tweet)
        db_session.commit()
        db_session.refresh(tweet)
        assert tweet.id is not None
        assert tweet.tweet_id == "123456789"
        assert tweet.text == "Test tweet content"
        assert tweet.retweet_count == 10
        assert tweet.like_count == 20
        assert tweet.match_id == 1
        assert tweet.source == "twitter"

    def test_tweet_defaults(self, db_session: Session):
        """Omitted columns fall back to their model defaults."""
        tweet = Tweet(
            tweet_id="987654321",
            text="Another test tweet",
            created_at=datetime.now(timezone.utc)
        )
        db_session.add(tweet)
        db_session.commit()
        db_session.refresh(tweet)
        assert tweet.retweet_count == 0
        assert tweet.like_count == 0
        assert tweet.source == "twitter"
        assert tweet.match_id is None

    def test_tweet_unique_constraint(self, db_session: Session):
        """Committing a second row with the same tweet_id violates uniqueness."""
        # Local import keeps the narrowed exception type available here
        # without touching the module's import block.
        from sqlalchemy.exc import IntegrityError
        tweet1 = Tweet(
            tweet_id="111111111",
            text="First tweet",
            created_at=datetime.now(timezone.utc)
        )
        tweet2 = Tweet(
            tweet_id="111111111",  # Same tweet_id
            text="Second tweet",
            created_at=datetime.now(timezone.utc)
        )
        db_session.add(tweet1)
        db_session.commit()
        db_session.add(tweet2)
        # FIX: expect the specific IntegrityError instead of bare Exception,
        # so unrelated failures are not silently accepted by the test.
        with pytest.raises(IntegrityError):
            db_session.commit()

    def test_tweet_to_dict(self, db_session: Session):
        """to_dict() exposes every persisted column by name."""
        tweet = Tweet(
            tweet_id="222222222",
            text="Test tweet for dict",
            created_at=datetime.now(timezone.utc),
            retweet_count=5,
            like_count=10,
            match_id=2,
            source="reddit"
        )
        db_session.add(tweet)
        db_session.commit()
        db_session.refresh(tweet)
        tweet_dict = tweet.to_dict()
        assert tweet_dict['tweet_id'] == "222222222"
        assert tweet_dict['text'] == "Test tweet for dict"
        assert tweet_dict['retweet_count'] == 5
        assert tweet_dict['like_count'] == 10
        assert tweet_dict['match_id'] == 2
        assert tweet_dict['source'] == "reddit"
        assert 'id' in tweet_dict
        assert 'created_at' in tweet_dict

    def test_tweet_repr(self, db_session: Session):
        """__repr__ names the class and includes the key identifiers."""
        tweet = Tweet(
            tweet_id="333333333",
            text="Test tweet repr",
            created_at=datetime.now(timezone.utc),
            match_id=3
        )
        db_session.add(tweet)
        db_session.commit()
        db_session.refresh(tweet)
        repr_str = repr(tweet)
        assert "Tweet" in repr_str
        assert "id=" in repr_str
        assert "tweet_id=333333333" in repr_str
        assert "match_id=3" in repr_str

View File

@@ -0,0 +1,186 @@
"""
Unit tests for Twitter scraper.
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import Mock, patch, MagicMock
from app.scrapers.twitter_scraper import (
TwitterScraper,
RateLimitInfo,
TweetData,
create_twitter_scraper
)
@pytest.fixture
def mock_tweepy_client():
    """Patch tweepy.Client inside the scraper module, yielding the patch mock."""
    with patch('app.scrapers.twitter_scraper.tweepy.Client') as patched_client:
        yield patched_client
@pytest.fixture
def test_bearer_token():
    """Static bearer token shared across the scraper tests."""
    return "test_bearer_token_12345"
@pytest.fixture
def test_scraper(test_bearer_token):
    """Build a TwitterScraper with tweepy.Client patched during construction.

    NOTE(review): the patch only covers __init__; the client the returned
    scraper holds is the mock created here, not any later patch.
    """
    with patch('app.scrapers.twitter_scraper.tweepy.Client'):
        built = TwitterScraper(
            bearer_token=test_bearer_token,
            max_tweets_per_hour=100,
            rate_limit_alert_threshold=0.9,
            vip_match_ids=[1, 2, 3]
        )
    return built
class TestRateLimitInfo:
    """Test RateLimitInfo dataclass."""

    def test_usage_percentage(self):
        """usage_percentage reports the consumed fraction of the limit."""
        cases = [
            (100, 1000, 0.9),   # 900 of 1000 calls used
            (0, 1000, 1.0),     # fully exhausted
            (1000, 1000, 0.0),  # untouched
        ]
        for remaining, limit, expected in cases:
            info = RateLimitInfo(remaining=remaining, limit=limit, reset_time=None)
            assert info.usage_percentage == expected
class TestTweetData:
    """Test TweetData dataclass."""

    def test_tweet_data_creation(self):
        """Explicit fields are stored and source defaults to 'twitter'."""
        data = TweetData(
            tweet_id="123456789",
            text="Test tweet content",
            created_at=datetime.now(timezone.utc),
            retweet_count=10,
            like_count=20,
            match_id=1
        )
        assert data.tweet_id == "123456789"
        # Default source for this dataclass
        assert data.source == "twitter"
        assert data.match_id == 1
        assert data.retweet_count == 10
        assert data.like_count == 20
class TestTwitterScraper:
    """Test TwitterScraper class.

    The test_scraper fixture supplies an instance configured with
    max_tweets_per_hour=100 (but see NOTE below on the counts used here)
    and VIP match ids [1, 2, 3].
    """

    def test_scraper_initialization(self, test_bearer_token):
        """Test scraper initialization."""
        # Construct under a patched tweepy.Client so no network auth happens.
        with patch('app.scrapers.twitter_scraper.tweepy.Client'):
            scraper = TwitterScraper(
                bearer_token=test_bearer_token,
                max_tweets_per_hour=1000,
                rate_limit_alert_threshold=0.9,
                vip_match_ids=[1, 2, 3]
            )
            # Constructor arguments are stored verbatim; counters start at zero
            # and degraded (VIP-only) mode starts disabled.
            assert scraper.bearer_token == test_bearer_token
            assert scraper.max_tweets_per_hour == 1000
            assert scraper.rate_limit_alert_threshold == 0.9
            assert scraper.vip_match_ids == [1, 2, 3]
            assert scraper.vip_mode_only is False
            assert scraper.api_calls_made == 0

    def test_check_rate_limit_normal(self, test_scraper):
        """Test rate limit check under normal conditions."""
        # NOTE(review): comments here assume a budget of 1000 calls, but the
        # fixture passes max_tweets_per_hour=100 — confirm which quantity
        # _check_rate_limit compares api_calls_made against.
        test_scraper.api_calls_made = 800  # 80% usage
        assert test_scraper._check_rate_limit() is True

    def test_check_rate_limit_alert(self, test_scraper, caplog):
        """Test rate limit alert at threshold."""
        test_scraper.api_calls_made = 900  # 90% usage
        with caplog.at_level("WARNING"):
            result = test_scraper._check_rate_limit()
        # Still allowed to proceed, but a warning must be logged.
        assert result is True
        assert "Rate limit approaching" in caplog.text

    def test_check_rate_limit_exceeded(self, test_scraper, caplog):
        """Test rate limit exceeded."""
        test_scraper.api_calls_made = 1000  # 100% usage
        with caplog.at_level("ERROR"):
            result = test_scraper._check_rate_limit()
        # At the cap the check denies further calls and logs an error.
        assert result is False
        assert "Rate limit reached" in caplog.text

    def test_enable_vip_mode_only(self, test_scraper, caplog):
        """Test enabling VIP mode."""
        with caplog.at_level("WARNING"):
            test_scraper._enable_vip_mode_only()
        # Degraded mode is sticky on the instance and loudly logged.
        assert test_scraper.vip_mode_only is True
        assert "ENTERING DEGRADED MODE" in caplog.text
        assert "VIP match IDs:" in caplog.text

    def test_scrape_non_vip_in_vip_mode(self, test_scraper):
        """Test scraping non-VIP match when VIP mode is active."""
        test_scraper.vip_mode_only = True
        # Match 4 is outside the fixture's VIP list [1, 2, 3].
        with pytest.raises(ValueError, match="Match 4 is not VIP"):
            test_scraper.scrape_twitter_match(
                match_id=4,
                keywords=["test"],
                max_results=10
            )

    def test_scrape_vip_in_vip_mode(self, test_scraper, mock_tweepy_client):
        """Test scraping VIP match when VIP mode is active."""
        test_scraper.vip_mode_only = True
        # Mock API response with a single tweet payload.
        # NOTE(review): test_scraper's client was created under the fixture's
        # own (already exited) patch, while the response below is wired onto
        # mock_tweepy_client's return_value — verify the scraper actually
        # sees this stubbed search_recent_tweets.
        mock_response = Mock()
        mock_response.data = [
            Mock(
                id="123456789",
                text="Test tweet",
                created_at=datetime.now(timezone.utc),
                public_metrics={'retweet_count': 10, 'like_count': 20}
            )
        ]
        mock_tweepy_client.return_value.search_recent_tweets.return_value = mock_response
        result = test_scraper.scrape_twitter_match(
            match_id=1,
            keywords=["test"],
            max_results=10
        )
        # One tweet parsed and the API-call counter incremented once.
        assert len(result) == 1
        assert result[0].tweet_id == "123456789"
        assert test_scraper.api_calls_made == 1
class TestCreateTwitterScraper:
"""Test create_twitter_scraper factory function."""
def test_factory_function(self, test_bearer_token):
"""Test factory function creates scraper."""
with patch('app.scrapers.twitter_scraper.TwitterScraper') as MockScraper:
mock_instance = Mock()
MockScraper.return_value = mock_instance
result = create_twitter_scraper(
bearer_token=test_bearer_token,
vip_match_ids=[1, 2, 3]
)
MockScraper.assert_called_once()
assert result == mock_instance