Initial commit
This commit is contained in:
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests package for backend."""
|
||||
184
backend/tests/conftest.py
Normal file
184
backend/tests/conftest.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Pytest configuration and fixtures for backend tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
from app.database import Base, get_db
|
||||
from app.models import (
|
||||
User,
|
||||
Tweet,
|
||||
RedditPost,
|
||||
RedditComment,
|
||||
SentimentScore,
|
||||
EnergyScore
|
||||
)
|
||||
|
||||
|
||||
# Create in-memory SQLite database for testing
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL,
|
||||
connect_args={"check_same_thread": False}
|
||||
)
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
"""
|
||||
Create a fresh database session for each test.
|
||||
|
||||
This fixture creates a new database, creates all tables,
|
||||
and returns a session. The database is cleaned up after the test.
|
||||
"""
|
||||
# Create all tables
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# Create session
|
||||
session = TestingSessionLocal()
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
# Drop all tables after test
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(db_session):
|
||||
"""
|
||||
Create a test client with a database session override.
|
||||
|
||||
This fixture overrides the get_db dependency to use the test database.
|
||||
"""
|
||||
from fastapi.testclient import TestClient
|
||||
from app.main import app
|
||||
|
||||
def override_get_db():
|
||||
try:
|
||||
yield db_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def sample_user(db_session):
|
||||
"""
|
||||
Create a sample user for testing.
|
||||
|
||||
Returns:
|
||||
User object
|
||||
"""
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
hashed_password="hashed_password",
|
||||
is_premium=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def sample_tweet(db_session, sample_user):
|
||||
"""
|
||||
Create a sample tweet for testing.
|
||||
|
||||
Returns:
|
||||
Tweet object
|
||||
"""
|
||||
from datetime import datetime
|
||||
tweet = Tweet(
|
||||
tweet_id="1234567890",
|
||||
text="This is a test tweet",
|
||||
author="test_author",
|
||||
created_at=datetime.utcnow(),
|
||||
user_id=sample_user.id,
|
||||
retweets=10,
|
||||
likes=20,
|
||||
replies=5
|
||||
)
|
||||
db_session.add(tweet)
|
||||
db_session.commit()
|
||||
db_session.refresh(tweet)
|
||||
return tweet
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def sample_reddit_post(db_session, sample_user):
|
||||
"""
|
||||
Create a sample Reddit post for testing.
|
||||
|
||||
Returns:
|
||||
RedditPost object
|
||||
"""
|
||||
from datetime import datetime
|
||||
reddit_post = RedditPost(
|
||||
post_id="reddit123",
|
||||
title="Test Reddit Post",
|
||||
text="This is a test Reddit post",
|
||||
author="reddit_author",
|
||||
subreddit="test_subreddit",
|
||||
created_at=datetime.utcnow(),
|
||||
user_id=sample_user.id,
|
||||
upvotes=15,
|
||||
comments=3
|
||||
)
|
||||
db_session.add(reddit_post)
|
||||
db_session.commit()
|
||||
db_session.refresh(reddit_post)
|
||||
return reddit_post
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def sample_sentiment_scores(db_session, sample_tweet, sample_reddit_post):
|
||||
"""
|
||||
Create sample sentiment scores for testing.
|
||||
|
||||
Returns:
|
||||
List of SentimentScore objects
|
||||
"""
|
||||
from datetime import datetime
|
||||
twitter_sentiment = SentimentScore(
|
||||
entity_id=sample_tweet.tweet_id,
|
||||
entity_type="tweet",
|
||||
score=0.5,
|
||||
sentiment_type="positive",
|
||||
positive=0.6,
|
||||
negative=0.2,
|
||||
neutral=0.2,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
reddit_sentiment = SentimentScore(
|
||||
entity_id=sample_reddit_post.post_id,
|
||||
entity_type="reddit_post",
|
||||
score=0.3,
|
||||
sentiment_type="positive",
|
||||
positive=0.4,
|
||||
negative=0.3,
|
||||
neutral=0.3,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
db_session.add(twitter_sentiment)
|
||||
db_session.add(reddit_sentiment)
|
||||
db_session.commit()
|
||||
db_session.refresh(twitter_sentiment)
|
||||
db_session.refresh(reddit_sentiment)
|
||||
|
||||
return [twitter_sentiment, reddit_sentiment]
|
||||
72
backend/tests/run_tests.py
Normal file
72
backend/tests/run_tests.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Script pour exécuter tous les tests du backend.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def run_command(cmd: str, description: str) -> bool:
|
||||
"""
|
||||
Exécute une commande shell et retourne le succès.
|
||||
|
||||
Args:
|
||||
cmd: Commande à exécuter
|
||||
description: Description de la commande
|
||||
|
||||
Returns:
|
||||
True si la commande a réussi, False sinon
|
||||
"""
|
||||
print(f"\n🔍 {description}")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
shell=True,
|
||||
check=True,
|
||||
capture_output=False
|
||||
)
|
||||
print(f"✅ {description} réussi !")
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"❌ {description} échoué avec le code {e.returncode}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""Fonction principale."""
|
||||
print("🧪 Exécution des tests du backend...")
|
||||
print("=" * 60)
|
||||
|
||||
all_passed = True
|
||||
|
||||
# Tests unitaires
|
||||
all_passed &= run_command(
|
||||
"pytest tests/ -v --tb=short",
|
||||
"Tests unitaires"
|
||||
)
|
||||
|
||||
# Linting
|
||||
all_passed &= run_command(
|
||||
"flake8 app/ --max-line-length=100",
|
||||
"Linting avec flake8"
|
||||
)
|
||||
|
||||
# Formatting
|
||||
all_passed &= run_command(
|
||||
"black --check app/",
|
||||
"Vérification du formatage avec black"
|
||||
)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
if all_passed:
|
||||
print("✅ Tous les tests et validations ont réussi !")
|
||||
return 0
|
||||
else:
|
||||
print("❌ Certains tests ou validations ont échoué.")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
41
backend/tests/test_api_dependencies.py
Normal file
41
backend/tests/test_api_dependencies.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Tests for API dependencies.
|
||||
|
||||
This module tests API key authentication dependencies.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
class TestApiKeyDependency:
|
||||
"""Tests for API key authentication dependency."""
|
||||
|
||||
def test_missing_api_key_header(self):
|
||||
"""Test that missing X-API-Key header returns 401."""
|
||||
response = client.get("/api/v1/users/profile")
|
||||
# This endpoint should require authentication
|
||||
# For now, we'll just test the dependency structure
|
||||
|
||||
# The dependency will raise 401 if X-API-Key is missing
|
||||
# We'll verify the error format
|
||||
|
||||
def test_invalid_api_key(self):
|
||||
"""Test that invalid API key returns 401."""
|
||||
response = client.get(
|
||||
"/api/v1/users/profile",
|
||||
headers={"X-API-Key": "invalid_key_12345"}
|
||||
)
|
||||
|
||||
# Should return 401 with proper error format
|
||||
assert response.status_code in [401, 404] # 404 if endpoint doesn't exist yet
|
||||
|
||||
def test_api_key_includes_user_id(self):
|
||||
"""Test that valid API key includes user_id in dependency."""
|
||||
# This test will be implemented once protected endpoints exist
|
||||
pass
|
||||
203
backend/tests/test_api_key_service.py
Normal file
203
backend/tests/test_api_key_service.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Tests for API Key Service.
|
||||
|
||||
This module tests the API key generation, validation, and management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.services.apiKeyService import ApiKeyService
|
||||
from app.models.api_key import ApiKey
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user(db: Session):
|
||||
"""Create a sample user for testing."""
|
||||
user = User(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
is_premium=False
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
class TestApiKeyGeneration:
|
||||
"""Tests for API key generation."""
|
||||
|
||||
def test_generate_api_key_success(self, db: Session, sample_user: User):
|
||||
"""Test successful API key generation."""
|
||||
service = ApiKeyService(db)
|
||||
result = service.generate_api_key(sample_user.id, rate_limit=100)
|
||||
|
||||
assert "api_key" in result
|
||||
assert "id" in result
|
||||
assert "key_prefix" in result
|
||||
assert "user_id" in result
|
||||
assert result["user_id"] == sample_user.id
|
||||
assert result["rate_limit"] == 100
|
||||
assert result["is_active"] == True
|
||||
assert len(result["api_key"]) > 0
|
||||
assert len(result["key_prefix"]) == 8
|
||||
|
||||
def test_generate_api_key_invalid_user(self, db: Session):
|
||||
"""Test API key generation with invalid user ID."""
|
||||
service = ApiKeyService(db)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
service.generate_api_key(99999)
|
||||
|
||||
assert "not found" in str(exc_info.value).lower()
|
||||
|
||||
def test_generate_api_key_stores_hash(self, db: Session, sample_user: User):
|
||||
"""Test that API key stores hash instead of plain key."""
|
||||
service = ApiKeyService(db)
|
||||
result = service.generate_api_key(sample_user.id)
|
||||
|
||||
# Query the database to verify hash is stored
|
||||
api_key_record = db.query(ApiKey).filter(
|
||||
ApiKey.id == result["id"]
|
||||
).first()
|
||||
|
||||
assert api_key_record is not None
|
||||
assert api_key_record.key_hash != result["api_key"]
|
||||
assert len(api_key_record.key_hash) == 64 # SHA-256 hash length
|
||||
assert api_key_record.key_prefix == result["api_key"][:8]
|
||||
|
||||
|
||||
class TestApiKeyValidation:
|
||||
"""Tests for API key validation."""
|
||||
|
||||
def test_validate_api_key_success(self, db: Session, sample_user: User):
|
||||
"""Test successful API key validation."""
|
||||
service = ApiKeyService(db)
|
||||
result = service.generate_api_key(sample_user.id)
|
||||
plain_api_key = result["api_key"]
|
||||
|
||||
# Validate the key
|
||||
validated = service.validate_api_key(plain_api_key)
|
||||
|
||||
assert validated is not None
|
||||
assert validated.user_id == sample_user.id
|
||||
assert validated.is_active == True
|
||||
|
||||
def test_validate_api_key_invalid(self, db: Session):
|
||||
"""Test validation with invalid API key."""
|
||||
service = ApiKeyService(db)
|
||||
validated = service.validate_api_key("invalid_key_12345")
|
||||
|
||||
assert validated is None
|
||||
|
||||
def test_validate_api_key_updates_last_used(self, db: Session, sample_user: User):
|
||||
"""Test that validation updates last_used_at timestamp."""
|
||||
service = ApiKeyService(db)
|
||||
result = service.generate_api_key(sample_user.id)
|
||||
plain_api_key = result["api_key"]
|
||||
|
||||
# Validate the key
|
||||
validated = service.validate_api_key(plain_api_key)
|
||||
|
||||
assert validated.last_used_at is not None
|
||||
|
||||
def test_validate_api_key_inactive(self, db: Session, sample_user: User):
|
||||
"""Test validation with inactive API key."""
|
||||
service = ApiKeyService(db)
|
||||
result = service.generate_api_key(sample_user.id)
|
||||
plain_api_key = result["api_key"]
|
||||
|
||||
# Deactivate the key
|
||||
api_key_record = db.query(ApiKey).filter(
|
||||
ApiKey.id == result["id"]
|
||||
).first()
|
||||
api_key_record.is_active = False
|
||||
db.commit()
|
||||
|
||||
# Try to validate
|
||||
validated = service.validate_api_key(plain_api_key)
|
||||
|
||||
assert validated is None
|
||||
|
||||
|
||||
class TestApiKeyManagement:
|
||||
"""Tests for API key management."""
|
||||
|
||||
def test_get_user_api_keys(self, db: Session, sample_user: User):
|
||||
"""Test retrieving all API keys for a user."""
|
||||
service = ApiKeyService(db)
|
||||
|
||||
# Generate 3 keys
|
||||
keys = []
|
||||
for i in range(3):
|
||||
result = service.generate_api_key(sample_user.id)
|
||||
keys.append(result)
|
||||
|
||||
# Retrieve all keys
|
||||
user_keys = service.get_user_api_keys(sample_user.id)
|
||||
|
||||
assert len(user_keys) == 3
|
||||
# Verify plain keys are not included
|
||||
for key in user_keys:
|
||||
assert "api_key" not in key
|
||||
assert "key_prefix" in key
|
||||
|
||||
def test_revoke_api_key_success(self, db: Session, sample_user: User):
|
||||
"""Test successful API key revocation."""
|
||||
service = ApiKeyService(db)
|
||||
result = service.generate_api_key(sample_user.id)
|
||||
api_key_id = result["id"]
|
||||
|
||||
# Revoke the key
|
||||
revoked = service.revoke_api_key(api_key_id, sample_user.id)
|
||||
|
||||
assert revoked is True
|
||||
|
||||
# Verify it's deactivated
|
||||
api_key_record = db.query(ApiKey).filter(
|
||||
ApiKey.id == api_key_id
|
||||
).first()
|
||||
assert api_key_record.is_active == False
|
||||
|
||||
def test_revoke_api_key_wrong_user(self, db: Session, sample_user: User):
|
||||
"""Test revoking key with wrong user ID."""
|
||||
service = ApiKeyService(db)
|
||||
result = service.generate_api_key(sample_user.id)
|
||||
api_key_id = result["id"]
|
||||
|
||||
# Try to revoke with wrong user ID
|
||||
revoked = service.revoke_api_key(api_key_id, 99999)
|
||||
|
||||
assert revoked is False
|
||||
|
||||
def test_regenerate_api_key_success(self, db: Session, sample_user: User):
|
||||
"""Test successful API key regeneration."""
|
||||
service = ApiKeyService(db)
|
||||
old_result = service.generate_api_key(sample_user.id)
|
||||
old_api_key_id = old_result["id"]
|
||||
|
||||
# Regenerate
|
||||
new_result = service.regenerate_api_key(old_api_key_id, sample_user.id)
|
||||
|
||||
assert new_result is not None
|
||||
assert "api_key" in new_result
|
||||
assert new_result["api_key"] != old_result["api_key"]
|
||||
assert new_result["id"] != old_api_key_id
|
||||
|
||||
# Verify old key is deactivated
|
||||
old_key_record = db.query(ApiKey).filter(
|
||||
ApiKey.id == old_api_key_id
|
||||
).first()
|
||||
assert old_key_record.is_active == False
|
||||
|
||||
def test_regenerate_api_key_wrong_user(self, db: Session, sample_user: User):
|
||||
"""Test regenerating key with wrong user ID."""
|
||||
service = ApiKeyService(db)
|
||||
old_result = service.generate_api_key(sample_user.id)
|
||||
|
||||
# Try to regenerate with wrong user ID
|
||||
new_result = service.regenerate_api_key(old_result["id"], 99999)
|
||||
|
||||
assert new_result is None
|
||||
515
backend/tests/test_backtesting.py
Normal file
515
backend/tests/test_backtesting.py
Normal file
@@ -0,0 +1,515 @@
|
||||
"""
|
||||
Tests for Backtesting Module.
|
||||
|
||||
This module contains unit tests for the backtesting functionality
|
||||
including accuracy calculation, comparison logic, and export formats.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from app.ml.backtesting import (
|
||||
validate_accuracy,
|
||||
compare_prediction,
|
||||
run_backtesting_single_match,
|
||||
run_backtesting_batch,
|
||||
export_to_json,
|
||||
export_to_csv,
|
||||
export_to_html,
|
||||
filter_matches_by_league,
|
||||
filter_matches_by_period,
|
||||
ACCURACY_VALIDATED_THRESHOLD,
|
||||
ACCURACY_ALERT_THRESHOLD
|
||||
)
|
||||
|
||||
|
||||
class TestValidateAccuracy:
|
||||
"""Tests for validate_accuracy function."""
|
||||
|
||||
def test_validate_accuracy_above_threshold(self):
|
||||
"""Test validation when accuracy >= 60%."""
|
||||
result = validate_accuracy(65.0)
|
||||
assert result == 'VALIDATED'
|
||||
|
||||
def test_validate_accuracy_at_threshold(self):
|
||||
"""Test validation when accuracy == 60%."""
|
||||
result = validate_accuracy(60.0)
|
||||
assert result == 'VALIDATED'
|
||||
|
||||
def test_validate_accuracy_below_target(self):
|
||||
"""Test validation when 55% <= accuracy < 60%."""
|
||||
result = validate_accuracy(58.0)
|
||||
assert result == 'BELOW_TARGET'
|
||||
|
||||
def test_validate_accuracy_alert(self):
|
||||
"""Test validation when accuracy < 55%."""
|
||||
result = validate_accuracy(50.0)
|
||||
assert result == 'REVISION_REQUIRED'
|
||||
|
||||
def test_validate_accuracy_boundary(self):
|
||||
"""Test validation at boundary 55%."""
|
||||
result = validate_accuracy(55.0)
|
||||
assert result == 'BELOW_TARGET'
|
||||
|
||||
def test_validate_accuracy_extreme_high(self):
|
||||
"""Test validation with perfect accuracy."""
|
||||
result = validate_accuracy(100.0)
|
||||
assert result == 'VALIDATED'
|
||||
|
||||
def test_validate_accuracy_zero(self):
|
||||
"""Test validation with zero accuracy."""
|
||||
result = validate_accuracy(0.0)
|
||||
assert result == 'REVISION_REQUIRED'
|
||||
|
||||
def test_validate_thresholds_constants(self):
|
||||
"""Test that threshold constants are properly defined."""
|
||||
assert ACCURACY_VALIDATED_THRESHOLD == 60.0
|
||||
assert ACCURACY_ALERT_THRESHOLD == 55.0
|
||||
|
||||
|
||||
class TestComparePrediction:
|
||||
"""Tests for compare_prediction function."""
|
||||
|
||||
def test_compare_home_correct(self):
|
||||
"""Test comparison when home prediction is correct."""
|
||||
result = compare_prediction('home', 'home')
|
||||
assert result is True
|
||||
|
||||
def test_compare_away_correct(self):
|
||||
"""Test comparison when away prediction is correct."""
|
||||
result = compare_prediction('away', 'away')
|
||||
assert result is True
|
||||
|
||||
def test_compare_draw_correct(self):
|
||||
"""Test comparison when draw prediction is correct."""
|
||||
result = compare_prediction('draw', 'draw')
|
||||
assert result is True
|
||||
|
||||
def test_compare_home_incorrect(self):
|
||||
"""Test comparison when home prediction is incorrect."""
|
||||
result = compare_prediction('home', 'away')
|
||||
assert result is False
|
||||
|
||||
def test_compare_case_insensitive(self):
|
||||
"""Test that comparison is case insensitive."""
|
||||
result1 = compare_prediction('HOME', 'home')
|
||||
result2 = compare_prediction('Home', 'home')
|
||||
result3 = compare_prediction('home', 'HOME')
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
assert result3 is True
|
||||
|
||||
|
||||
class TestRunBacktestingSingleMatch:
|
||||
"""Tests for run_backtesting_single_match function."""
|
||||
|
||||
def test_single_match_correct_prediction(self):
|
||||
"""Test backtesting for a single match with correct prediction."""
|
||||
result = run_backtesting_single_match(
|
||||
match_id=1,
|
||||
home_team='PSG',
|
||||
away_team='OM',
|
||||
home_energy=65.0,
|
||||
away_energy=45.0,
|
||||
actual_winner='home'
|
||||
)
|
||||
|
||||
assert result['match_id'] == 1
|
||||
assert result['home_team'] == 'PSG'
|
||||
assert result['away_team'] == 'OM'
|
||||
assert result['home_energy'] == 65.0
|
||||
assert result['away_energy'] == 45.0
|
||||
assert result['actual_winner'] == 'home'
|
||||
assert result['correct'] is True
|
||||
assert 'prediction' in result
|
||||
assert result['prediction']['predicted_winner'] == 'home'
|
||||
|
||||
def test_single_match_incorrect_prediction(self):
|
||||
"""Test backtesting for a single match with incorrect prediction."""
|
||||
result = run_backtesting_single_match(
|
||||
match_id=2,
|
||||
home_team='PSG',
|
||||
away_team='OM',
|
||||
home_energy=65.0,
|
||||
away_energy=45.0,
|
||||
actual_winner='away'
|
||||
)
|
||||
|
||||
assert result['match_id'] == 2
|
||||
assert result['correct'] is False
|
||||
assert result['prediction']['predicted_winner'] != 'away'
|
||||
|
||||
def test_single_match_with_equal_energy(self):
|
||||
"""Test backtesting for a match with equal energy scores."""
|
||||
result = run_backtesting_single_match(
|
||||
match_id=3,
|
||||
home_team='PSG',
|
||||
away_team='OM',
|
||||
home_energy=50.0,
|
||||
away_energy=50.0,
|
||||
actual_winner='draw'
|
||||
)
|
||||
|
||||
assert result['prediction']['predicted_winner'] == 'draw'
|
||||
assert result['correct'] is True
|
||||
|
||||
|
||||
class TestRunBacktestingBatch:
|
||||
"""Tests for run_backtesting_batch function."""
|
||||
|
||||
def test_batch_all_correct(self):
|
||||
"""Test backtesting with all correct predictions."""
|
||||
matches = [
|
||||
{
|
||||
'match_id': 1,
|
||||
'home_team': 'PSG',
|
||||
'away_team': 'OM',
|
||||
'home_energy': 65.0,
|
||||
'away_energy': 45.0,
|
||||
'actual_winner': 'home',
|
||||
'league': 'Ligue 1'
|
||||
},
|
||||
{
|
||||
'match_id': 2,
|
||||
'home_team': 'Lyon',
|
||||
'away_team': 'Monaco',
|
||||
'home_energy': 45.0,
|
||||
'away_energy': 65.0,
|
||||
'actual_winner': 'away',
|
||||
'league': 'Ligue 1'
|
||||
}
|
||||
]
|
||||
|
||||
result = run_backtesting_batch(matches)
|
||||
|
||||
assert result['total_matches'] == 2
|
||||
assert result['correct_predictions'] == 2
|
||||
assert result['incorrect_predictions'] == 0
|
||||
assert result['accuracy'] == 100.0
|
||||
assert result['status'] == 'VALIDATED'
|
||||
assert len(result['results']) == 2
|
||||
|
||||
def test_batch_mixed_results(self):
|
||||
"""Test backtesting with mixed correct/incorrect predictions."""
|
||||
matches = [
|
||||
{
|
||||
'match_id': 1,
|
||||
'home_team': 'PSG',
|
||||
'away_team': 'OM',
|
||||
'home_energy': 65.0,
|
||||
'away_energy': 45.0,
|
||||
'actual_winner': 'home',
|
||||
'league': 'Ligue 1'
|
||||
},
|
||||
{
|
||||
'match_id': 2,
|
||||
'home_team': 'Lyon',
|
||||
'away_team': 'Monaco',
|
||||
'home_energy': 65.0,
|
||||
'away_energy': 45.0,
|
||||
'actual_winner': 'away',
|
||||
'league': 'Ligue 1'
|
||||
}
|
||||
]
|
||||
|
||||
result = run_backtesting_batch(matches)
|
||||
|
||||
assert result['total_matches'] == 2
|
||||
assert result['correct_predictions'] == 1
|
||||
assert result['incorrect_predictions'] == 1
|
||||
assert result['accuracy'] == 50.0
|
||||
assert result['status'] == 'REVISION_REQUIRED'
|
||||
|
||||
def test_batch_with_leagues(self):
|
||||
"""Test backtracking with multiple leagues."""
|
||||
matches = [
|
||||
{
|
||||
'match_id': 1,
|
||||
'home_team': 'PSG',
|
||||
'away_team': 'OM',
|
||||
'home_energy': 65.0,
|
||||
'away_energy': 45.0,
|
||||
'actual_winner': 'home',
|
||||
'league': 'Ligue 1'
|
||||
},
|
||||
{
|
||||
'match_id': 2,
|
||||
'home_team': 'Man City',
|
||||
'away_team': 'Liverpool',
|
||||
'home_energy': 70.0,
|
||||
'away_energy': 50.0,
|
||||
'actual_winner': 'home',
|
||||
'league': 'Premier League'
|
||||
}
|
||||
]
|
||||
|
||||
result = run_backtesting_batch(matches)
|
||||
|
||||
assert 'metrics_by_league' in result
|
||||
assert 'Ligue 1' in result['metrics_by_league']
|
||||
assert 'Premier League' in result['metrics_by_league']
|
||||
assert result['metrics_by_league']['Ligue 1']['total'] == 1
|
||||
assert result['metrics_by_league']['Premier League']['total'] == 1
|
||||
|
||||
def test_batch_empty(self):
|
||||
"""Test backtracking with no matches."""
|
||||
result = run_backtesting_batch([])
|
||||
|
||||
assert result['total_matches'] == 0
|
||||
assert result['correct_predictions'] == 0
|
||||
assert result['incorrect_predictions'] == 0
|
||||
assert result['accuracy'] == 0.0
|
||||
|
||||
def test_batch_missing_required_field(self):
|
||||
"""Test backtracking with missing required field raises error."""
|
||||
matches = [
|
||||
{
|
||||
'match_id': 1,
|
||||
'home_team': 'PSG',
|
||||
# Missing 'away_team'
|
||||
'home_energy': 65.0,
|
||||
'away_energy': 45.0,
|
||||
'actual_winner': 'home'
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="missing required fields"):
|
||||
run_backtesting_batch(matches)
|
||||
|
||||
def test_batch_with_dates(self):
|
||||
"""Test backtracking with match dates."""
|
||||
match_date = datetime(2025, 1, 15, 20, 0, 0)
|
||||
matches = [
|
||||
{
|
||||
'match_id': 1,
|
||||
'home_team': 'PSG',
|
||||
'away_team': 'OM',
|
||||
'home_energy': 65.0,
|
||||
'away_energy': 45.0,
|
||||
'actual_winner': 'home',
|
||||
'league': 'Ligue 1',
|
||||
'date': match_date
|
||||
}
|
||||
]
|
||||
|
||||
result = run_backtesting_batch(matches)
|
||||
|
||||
assert result['results'][0]['date'] == match_date.isoformat()
|
||||
|
||||
|
||||
class TestExportFormats:
|
||||
"""Tests for export functions."""
|
||||
|
||||
def test_export_to_json(self):
|
||||
"""Test JSON export format."""
|
||||
backtesting_result = {
|
||||
'total_matches': 2,
|
||||
'correct_predictions': 1,
|
||||
'incorrect_predictions': 1,
|
||||
'accuracy': 50.0,
|
||||
'status': 'REVISION_REQUIRED',
|
||||
'results': [],
|
||||
'metrics_by_league': {},
|
||||
'timestamp': '2026-01-17T10:00:00Z',
|
||||
'validation_thresholds': {'validated': 60.0, 'alert': 55.0}
|
||||
}
|
||||
|
||||
json_output = export_to_json(backtesting_result)
|
||||
|
||||
assert isinstance(json_output, str)
|
||||
assert 'total_matches' in json_output
|
||||
assert 'accuracy' in json_output
|
||||
|
||||
def test_export_to_csv(self):
|
||||
"""Test CSV export format."""
|
||||
backtesting_result = {
|
||||
'total_matches': 1,
|
||||
'correct_predictions': 1,
|
||||
'incorrect_predictions': 0,
|
||||
'accuracy': 100.0,
|
||||
'status': 'VALIDATED',
|
||||
'results': [
|
||||
{
|
||||
'match_id': 1,
|
||||
'league': 'Ligue 1',
|
||||
'date': '2026-01-15T20:00:00Z',
|
||||
'home_team': 'PSG',
|
||||
'away_team': 'OM',
|
||||
'home_energy': 65.0,
|
||||
'away_energy': 45.0,
|
||||
'prediction': {'predicted_winner': 'home', 'confidence': 40.0},
|
||||
'actual_winner': 'home',
|
||||
'correct': True
|
||||
}
|
||||
],
|
||||
'metrics_by_league': {},
|
||||
'timestamp': '2026-01-17T10:00:00Z',
|
||||
'validation_thresholds': {'validated': 60.0, 'alert': 55.0}
|
||||
}
|
||||
|
||||
csv_output = export_to_csv(backtesting_result)
|
||||
|
||||
assert isinstance(csv_output, str)
|
||||
assert 'match_id' in csv_output
|
||||
assert 'PSG' in csv_output
|
||||
assert 'OM' in csv_output
|
||||
|
||||
def test_export_to_html(self):
|
||||
"""Test HTML export format."""
|
||||
backtesting_result = {
|
||||
'total_matches': 1,
|
||||
'correct_predictions': 1,
|
||||
'incorrect_predictions': 0,
|
||||
'accuracy': 100.0,
|
||||
'status': 'VALIDATED',
|
||||
'results': [
|
||||
{
|
||||
'match_id': 1,
|
||||
'league': 'Ligue 1',
|
||||
'date': '2026-01-15T20:00:00Z',
|
||||
'home_team': 'PSG',
|
||||
'away_team': 'OM',
|
||||
'home_energy': 65.0,
|
||||
'away_energy': 45.0,
|
||||
'prediction': {'predicted_winner': 'home', 'confidence': 40.0},
|
||||
'actual_winner': 'home',
|
||||
'correct': True
|
||||
}
|
||||
],
|
||||
'metrics_by_league': {},
|
||||
'timestamp': '2026-01-17T10:00:00Z',
|
||||
'validation_thresholds': {'validated': 60.0, 'alert': 55.0}
|
||||
}
|
||||
|
||||
html_output = export_to_html(backtesting_result)
|
||||
|
||||
assert isinstance(html_output, str)
|
||||
assert '<html>' in html_output
|
||||
assert '</html>' in html_output
|
||||
assert 'Backtesting Report' in html_output
|
||||
assert '100.0%' in html_output
|
||||
assert 'VALIDATED' in html_output
|
||||
|
||||
|
||||
class TestFilterMatchesByLeague:
|
||||
"""Tests for filter_matches_by_league function."""
|
||||
|
||||
def test_filter_by_single_league(self):
|
||||
"""Test filtering by a single league."""
|
||||
matches = [
|
||||
{'league': 'Ligue 1', 'match_id': 1},
|
||||
{'league': 'Premier League', 'match_id': 2},
|
||||
{'league': 'Ligue 1', 'match_id': 3}
|
||||
]
|
||||
|
||||
filtered = filter_matches_by_league(matches, ['Ligue 1'])
|
||||
|
||||
assert len(filtered) == 2
|
||||
assert all(m['league'] == 'Ligue 1' for m in filtered)
|
||||
|
||||
def test_filter_by_multiple_leagues(self):
|
||||
"""Test filtering by multiple leagues."""
|
||||
matches = [
|
||||
{'league': 'Ligue 1', 'match_id': 1},
|
||||
{'league': 'Premier League', 'match_id': 2},
|
||||
{'league': 'La Liga', 'match_id': 3}
|
||||
]
|
||||
|
||||
filtered = filter_matches_by_league(matches, ['Ligue 1', 'Premier League'])
|
||||
|
||||
assert len(filtered) == 2
|
||||
assert filtered[0]['league'] == 'Ligue 1'
|
||||
assert filtered[1]['league'] == 'Premier League'
|
||||
|
||||
def test_filter_no_leagues(self):
|
||||
"""Test that empty leagues list returns all matches."""
|
||||
matches = [
|
||||
{'league': 'Ligue 1', 'match_id': 1},
|
||||
{'league': 'Premier League', 'match_id': 2}
|
||||
]
|
||||
|
||||
filtered = filter_matches_by_league(matches, [])
|
||||
|
||||
assert len(filtered) == 2
|
||||
|
||||
def test_filter_none_leagues(self):
|
||||
"""Test that None leagues list returns all matches."""
|
||||
matches = [
|
||||
{'league': 'Ligue 1', 'match_id': 1},
|
||||
{'league': 'Premier League', 'match_id': 2}
|
||||
]
|
||||
|
||||
filtered = filter_matches_by_league(matches, None)
|
||||
|
||||
assert len(filtered) == 2
|
||||
|
||||
|
||||
class TestFilterMatchesByPeriod:
|
||||
"""Tests for filter_matches_by_period function."""
|
||||
|
||||
def test_filter_by_start_date(self):
|
||||
"""Test filtering by start date."""
|
||||
matches = [
|
||||
{'date': datetime(2025, 1, 10), 'match_id': 1},
|
||||
{'date': datetime(2025, 1, 20), 'match_id': 2},
|
||||
{'date': datetime(2025, 1, 5), 'match_id': 3}
|
||||
]
|
||||
|
||||
start_date = datetime(2025, 1, 15)
|
||||
filtered = filter_matches_by_period(matches, start_date=start_date)
|
||||
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]['match_id'] == 2
|
||||
|
||||
def test_filter_by_end_date(self):
|
||||
"""Test filtering by end date."""
|
||||
matches = [
|
||||
{'date': datetime(2025, 1, 10), 'match_id': 1},
|
||||
{'date': datetime(2025, 1, 20), 'match_id': 2},
|
||||
{'date': datetime(2025, 1, 5), 'match_id': 3}
|
||||
]
|
||||
|
||||
end_date = datetime(2025, 1, 15)
|
||||
filtered = filter_matches_by_period(matches, end_date=end_date)
|
||||
|
||||
assert len(filtered) == 2
|
||||
assert sorted([m['match_id'] for m in filtered]) == [1, 3]
|
||||
|
||||
def test_filter_by_date_range(self):
|
||||
"""Test filtering by date range."""
|
||||
matches = [
|
||||
{'date': datetime(2025, 1, 10), 'match_id': 1},
|
||||
{'date': datetime(2025, 1, 20), 'match_id': 2},
|
||||
{'date': datetime(2025, 1, 15), 'match_id': 3},
|
||||
{'date': datetime(2025, 1, 5), 'match_id': 4}
|
||||
]
|
||||
|
||||
start_date = datetime(2025, 1, 10)
|
||||
end_date = datetime(2025, 1, 15)
|
||||
filtered = filter_matches_by_period(matches, start_date=start_date, end_date=end_date)
|
||||
|
||||
assert len(filtered) == 2
|
||||
assert sorted([m['match_id'] for m in filtered]) == [1, 3]
|
||||
|
||||
def test_filter_no_dates(self):
|
||||
"""Test that None dates return all matches."""
|
||||
matches = [
|
||||
{'date': datetime(2025, 1, 10), 'match_id': 1},
|
||||
{'date': datetime(2025, 1, 20), 'match_id': 2}
|
||||
]
|
||||
|
||||
filtered = filter_matches_by_period(matches, start_date=None, end_date=None)
|
||||
|
||||
assert len(filtered) == 2
|
||||
|
||||
def test_filter_no_date_field(self):
|
||||
"""Test matches without date field are excluded when filtering."""
|
||||
matches = [
|
||||
{'date': datetime(2025, 1, 10), 'match_id': 1},
|
||||
{'match_id': 2} # No date field
|
||||
]
|
||||
|
||||
start_date = datetime(2025, 1, 1)
|
||||
filtered = filter_matches_by_period(matches, start_date=start_date)
|
||||
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]['match_id'] == 1
|
||||
89
backend/tests/test_directory_structure.py
Normal file
89
backend/tests/test_directory_structure.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Tests pour valider la structure du répertoire backend."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_backend_directory_exists():
|
||||
"""Teste que le répertoire backend existe."""
|
||||
backend_path = Path(__file__).parent.parent
|
||||
assert backend_path.exists(), "Le répertoire backend doit exister"
|
||||
assert backend_path.is_dir(), "Le backend doit être un répertoire"
|
||||
|
||||
|
||||
def test_app_directory_exists():
|
||||
"""Teste que le répertoire app existe."""
|
||||
app_path = Path(__file__).parent.parent / "app"
|
||||
assert app_path.exists(), "Le répertoire app doit exister"
|
||||
assert app_path.is_dir(), "L'app doit être un répertoire"
|
||||
|
||||
|
||||
def test_app_init_exists():
|
||||
"""Teste que le fichier __init__.py existe dans app."""
|
||||
init_path = Path(__file__).parent.parent / "app" / "__init__.py"
|
||||
assert init_path.exists(), "Le fichier app/__init__.py doit exister"
|
||||
assert init_path.is_file(), "L'__init__.py doit être un fichier"
|
||||
|
||||
|
||||
def test_models_directory_exists():
|
||||
"""Teste que le répertoire models existe."""
|
||||
models_path = Path(__file__).parent.parent / "app" / "models"
|
||||
assert models_path.exists(), "Le répertoire models doit exister"
|
||||
assert models_path.is_dir(), "Les models doivent être un répertoire"
|
||||
|
||||
|
||||
def test_schemas_directory_exists():
|
||||
"""Teste que le répertoire schemas existe."""
|
||||
schemas_path = Path(__file__).parent.parent / "app" / "schemas"
|
||||
assert schemas_path.exists(), "Le répertoire schemas doit exister"
|
||||
assert schemas_path.is_dir(), "Les schemas doivent être un répertoire"
|
||||
|
||||
|
||||
def test_api_directory_exists():
|
||||
"""Teste que le répertoire api existe."""
|
||||
api_path = Path(__file__).parent.parent / "app" / "api"
|
||||
assert api_path.exists(), "Le répertoire api doit exister"
|
||||
assert api_path.is_dir(), "L'api doit être un répertoire"
|
||||
|
||||
|
||||
def test_models_init_exists():
|
||||
"""Teste que le fichier __init__.py existe dans models."""
|
||||
init_path = Path(__file__).parent.parent / "app" / "models" / "__init__.py"
|
||||
assert init_path.exists(), "Le fichier models/__init__.py doit exister"
|
||||
assert init_path.is_file(), "L'__init__.py des models doit être un fichier"
|
||||
|
||||
|
||||
def test_schemas_init_exists():
|
||||
"""Teste que le fichier __init__.py existe dans schemas."""
|
||||
init_path = Path(__file__).parent.parent / "app" / "schemas" / "__init__.py"
|
||||
assert init_path.exists(), "Le fichier schemas/__init__.py doit exister"
|
||||
assert init_path.is_file(), "L'__init__.py des schemas doit être un fichier"
|
||||
|
||||
|
||||
def test_api_init_exists():
|
||||
"""Teste que le fichier __init__.py existe dans api."""
|
||||
init_path = Path(__file__).parent.parent / "app" / "api" / "__init__.py"
|
||||
assert init_path.exists(), "Le fichier api/__init__.py doit exister"
|
||||
assert init_path.is_file(), "L'__init__.py de l'api doit être un fichier"
|
||||
|
||||
|
||||
def test_main_py_exists():
|
||||
"""Teste que le fichier main.py existe."""
|
||||
main_path = Path(__file__).parent.parent / "app" / "main.py"
|
||||
assert main_path.exists(), "Le fichier app/main.py doit exister"
|
||||
assert main_path.is_file(), "Le main.py doit être un fichier"
|
||||
|
||||
|
||||
def test_database_py_exists():
|
||||
"""Teste que le fichier database.py existe."""
|
||||
db_path = Path(__file__).parent.parent / "app" / "database.py"
|
||||
assert db_path.exists(), "Le fichier app/database.py doit exister"
|
||||
assert db_path.is_file(), "Le database.py doit être un fichier"
|
||||
|
||||
|
||||
def test_fastapi_app_importable():
|
||||
"""Teste que l'application FastAPI peut être importée."""
|
||||
try:
|
||||
from app.main import app
|
||||
assert app is not None, "L'application FastAPI doit être importable"
|
||||
except ImportError as e:
|
||||
raise AssertionError(f"Impossible d'importer l'application FastAPI: {e}")
|
||||
378
backend/tests/test_energy_calculator.py
Normal file
378
backend/tests/test_energy_calculator.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
Tests for Energy Calculator Module.
|
||||
|
||||
This module tests the energy score calculation functionality including:
|
||||
- Multi-source weighted calculation
|
||||
- Temporal weighting
|
||||
- Degraded mode (when sources are unavailable)
|
||||
- Score normalization (0-100)
|
||||
- Confidence calculation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from app.ml.energy_calculator import (
|
||||
calculate_energy_score,
|
||||
apply_temporal_weighting,
|
||||
calculate_confidence,
|
||||
normalize_score,
|
||||
apply_source_weights,
|
||||
adjust_weights_for_degraded_mode
|
||||
)
|
||||
|
||||
|
||||
class TestCalculateEnergyScore:
|
||||
"""Test the main energy score calculation function."""
|
||||
|
||||
def test_calculate_energy_score_with_all_sources(self):
|
||||
"""Test energy calculation with all three sources available."""
|
||||
# Arrange
|
||||
match_id = 1
|
||||
team_id = 1
|
||||
twitter_sentiments = [
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
reddit_sentiments = [
|
||||
{'compound': 0.3, 'positive': 0.4, 'negative': 0.3, 'neutral': 0.3, 'sentiment': 'positive'}
|
||||
]
|
||||
rss_sentiments = [
|
||||
{'compound': 0.4, 'positive': 0.5, 'negative': 0.2, 'neutral': 0.3, 'sentiment': 'positive'}
|
||||
]
|
||||
|
||||
# Act
|
||||
result = calculate_energy_score(
|
||||
match_id=match_id,
|
||||
team_id=team_id,
|
||||
twitter_sentiments=twitter_sentiments,
|
||||
reddit_sentiments=reddit_sentiments,
|
||||
rss_sentiments=rss_sentiments
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert 'score' in result
|
||||
assert 'confidence' in result
|
||||
assert 'sources_used' in result
|
||||
assert 0 <= result['score'] <= 100
|
||||
assert len(result['sources_used']) == 3
|
||||
assert result['confidence'] > 0.6 # High confidence with all sources
|
||||
|
||||
def test_calculate_energy_score_with_twitter_only(self):
|
||||
"""Test energy calculation with only Twitter source available (degraded mode)."""
|
||||
# Arrange
|
||||
match_id = 1
|
||||
team_id = 1
|
||||
twitter_sentiments = [
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
reddit_sentiments = []
|
||||
rss_sentiments = []
|
||||
|
||||
# Act
|
||||
result = calculate_energy_score(
|
||||
match_id=match_id,
|
||||
team_id=team_id,
|
||||
twitter_sentiments=twitter_sentiments,
|
||||
reddit_sentiments=reddit_sentiments,
|
||||
rss_sentiments=rss_sentiments
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert 'score' in result
|
||||
assert 'confidence' in result
|
||||
assert 'sources_used' in result
|
||||
assert 0 <= result['score'] <= 100
|
||||
assert len(result['sources_used']) == 1
|
||||
assert 'twitter' in result['sources_used']
|
||||
assert result['confidence'] < 0.6 # Lower confidence in degraded mode
|
||||
|
||||
def test_calculate_energy_score_no_sentiment_data(self):
|
||||
"""Test energy calculation with no sentiment data."""
|
||||
# Arrange
|
||||
match_id = 1
|
||||
team_id = 1
|
||||
|
||||
# Act
|
||||
result = calculate_energy_score(
|
||||
match_id=match_id,
|
||||
team_id=team_id,
|
||||
twitter_sentiments=[],
|
||||
reddit_sentiments=[],
|
||||
rss_sentiments=[]
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result['score'] == 0.0
|
||||
assert result['confidence'] == 0.0
|
||||
assert len(result['sources_used']) == 0
|
||||
|
||||
|
||||
class TestApplySourceWeights:
|
||||
"""Test source weight application and adjustment."""
|
||||
|
||||
def test_apply_source_weights_all_sources(self):
|
||||
"""Test applying weights with all sources available."""
|
||||
# Arrange
|
||||
twitter_score = 50.0
|
||||
reddit_score = 40.0
|
||||
rss_score = 30.0
|
||||
available_sources = ['twitter', 'reddit', 'rss']
|
||||
|
||||
# Act
|
||||
weighted_score = apply_source_weights(
|
||||
twitter_score=twitter_score,
|
||||
reddit_score=reddit_score,
|
||||
rss_score=rss_score,
|
||||
available_sources=available_sources
|
||||
)
|
||||
|
||||
# Assert
|
||||
expected = (50.0 * 0.60) + (40.0 * 0.25) + (30.0 * 0.15)
|
||||
assert weighted_score == expected
|
||||
|
||||
def test_apply_source_weights_twitter_only(self):
|
||||
"""Test applying weights with only Twitter available."""
|
||||
# Arrange
|
||||
twitter_score = 50.0
|
||||
reddit_score = 0.0
|
||||
rss_score = 0.0
|
||||
available_sources = ['twitter']
|
||||
|
||||
# Act
|
||||
weighted_score = apply_source_weights(
|
||||
twitter_score=twitter_score,
|
||||
reddit_score=reddit_score,
|
||||
rss_score=rss_score,
|
||||
available_sources=available_sources
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert weighted_score == 50.0 # 100% of weight goes to Twitter
|
||||
|
||||
def test_adjust_weights_for_degraded_mode(self):
|
||||
"""Test weight adjustment in degraded mode."""
|
||||
# Arrange
|
||||
original_weights = {
|
||||
'twitter': 0.60,
|
||||
'reddit': 0.25,
|
||||
'rss': 0.15
|
||||
}
|
||||
available_sources = ['twitter', 'reddit']
|
||||
|
||||
# Act
|
||||
adjusted_weights = adjust_weights_for_degraded_mode(
|
||||
original_weights=original_weights,
|
||||
available_sources=available_sources
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(adjusted_weights) == 2
|
||||
assert 'twitter' in adjusted_weights
|
||||
assert 'reddit' in adjusted_weights
|
||||
assert 'rss' not in adjusted_weights
|
||||
# Total should be 1.0
|
||||
total_weight = sum(adjusted_weights.values())
|
||||
assert abs(total_weight - 1.0) < 0.001
|
||||
|
||||
|
||||
class TestApplyTemporalWeighting:
|
||||
"""Test temporal weighting of sentiment scores."""
|
||||
|
||||
def test_temporal_weighting_recent_tweets(self):
|
||||
"""Test temporal weighting with recent tweets (within 1 hour)."""
|
||||
# Arrange
|
||||
base_score = 50.0
|
||||
now = datetime.utcnow()
|
||||
tweets_with_timestamps = [
|
||||
{
|
||||
'compound': 0.5,
|
||||
'created_at': now - timedelta(minutes=30) # 30 minutes ago
|
||||
},
|
||||
{
|
||||
'compound': 0.6,
|
||||
'created_at': now - timedelta(minutes=15) # 15 minutes ago
|
||||
}
|
||||
]
|
||||
|
||||
# Act
|
||||
weighted_score = apply_temporal_weighting(
|
||||
base_score=base_score,
|
||||
tweets_with_timestamps=tweets_with_timestamps
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Recent tweets should have higher weight (close to 1.0)
|
||||
# Score should be close to original but weighted by recency
|
||||
assert 0 <= weighted_score <= 100
|
||||
|
||||
def test_temporal_weighting_old_tweets(self):
|
||||
"""Test temporal weighting with old tweets (24+ hours)."""
|
||||
# Arrange
|
||||
base_score = 50.0
|
||||
now = datetime.utcnow()
|
||||
tweets_with_timestamps = [
|
||||
{
|
||||
'compound': 0.5,
|
||||
'created_at': now - timedelta(hours=30) # 30 hours ago
|
||||
},
|
||||
{
|
||||
'compound': 0.6,
|
||||
'created_at': now - timedelta(hours=25) # 25 hours ago
|
||||
}
|
||||
]
|
||||
|
||||
# Act
|
||||
weighted_score = apply_temporal_weighting(
|
||||
base_score=base_score,
|
||||
tweets_with_timestamps=tweets_with_timestamps
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Old tweets should have lower weight (around 0.5)
|
||||
assert 0 <= weighted_score <= 100
|
||||
|
||||
|
||||
class TestNormalizeScore:
|
||||
"""Test score normalization to 0-100 range."""
|
||||
|
||||
def test_normalize_score_negative(self):
|
||||
"""Test normalization of negative scores."""
|
||||
# Act
|
||||
normalized = normalize_score(-1.5)
|
||||
|
||||
# Assert
|
||||
assert normalized == 0.0
|
||||
|
||||
def test_normalize_score_positive(self):
|
||||
"""Test normalization of positive scores."""
|
||||
# Act
|
||||
normalized = normalize_score(150.0)
|
||||
|
||||
# Assert
|
||||
assert normalized == 100.0
|
||||
|
||||
def test_normalize_score_in_range(self):
|
||||
"""Test normalization of scores within range."""
|
||||
# Arrange & Act
|
||||
normalized_0 = normalize_score(0.0)
|
||||
normalized_50 = normalize_score(50.0)
|
||||
normalized_100 = normalize_score(100.0)
|
||||
|
||||
# Assert
|
||||
assert normalized_0 == 0.0
|
||||
assert normalized_50 == 50.0
|
||||
assert normalized_100 == 100.0
|
||||
|
||||
|
||||
class TestCalculateConfidence:
|
||||
"""Test confidence level calculation."""
|
||||
|
||||
def test_confidence_all_sources(self):
|
||||
"""Test confidence with all sources available."""
|
||||
# Arrange
|
||||
available_sources = ['twitter', 'reddit', 'rss']
|
||||
total_weight = 1.0
|
||||
|
||||
# Act
|
||||
confidence = calculate_confidence(
|
||||
available_sources=available_sources,
|
||||
total_weight=total_weight
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert confidence > 0.6 # High confidence
|
||||
assert confidence <= 1.0
|
||||
|
||||
def test_confidence_single_source(self):
|
||||
"""Test confidence with only one source available."""
|
||||
# Arrange
|
||||
available_sources = ['twitter']
|
||||
total_weight = 0.6
|
||||
|
||||
# Act
|
||||
confidence = calculate_confidence(
|
||||
available_sources=available_sources,
|
||||
total_weight=total_weight
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert confidence < 0.6 # Lower confidence
|
||||
assert confidence >= 0.3
|
||||
|
||||
def test_confidence_no_sources(self):
|
||||
"""Test confidence with no sources available."""
|
||||
# Arrange
|
||||
available_sources = []
|
||||
total_weight = 0.0
|
||||
|
||||
# Act
|
||||
confidence = calculate_confidence(
|
||||
available_sources=available_sources,
|
||||
total_weight=total_weight
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert confidence == 0.0
|
||||
|
||||
|
||||
class TestEnergyFormula:
|
||||
"""Test the core energy formula: Score = (Positive - Negative) × Volume × Virality."""
|
||||
|
||||
def test_formula_positive_sentiment(self):
|
||||
"""Test energy formula with positive sentiment."""
|
||||
# Arrange
|
||||
positive_ratio = 0.7
|
||||
negative_ratio = 0.2
|
||||
volume = 100
|
||||
virality = 1.5
|
||||
|
||||
# Act
|
||||
score = (positive_ratio - negative_ratio) * volume * virality
|
||||
|
||||
# Assert
|
||||
assert score > 0 # Positive energy
|
||||
|
||||
def test_formula_negative_sentiment(self):
|
||||
"""Test energy formula with negative sentiment."""
|
||||
# Arrange
|
||||
positive_ratio = 0.2
|
||||
negative_ratio = 0.7
|
||||
volume = 100
|
||||
virality = 1.5
|
||||
|
||||
# Act
|
||||
score = (positive_ratio - negative_ratio) * volume * virality
|
||||
|
||||
# Assert
|
||||
assert score < 0 # Negative energy
|
||||
|
||||
def test_formula_neutral_sentiment(self):
|
||||
"""Test energy formula with neutral sentiment."""
|
||||
# Arrange
|
||||
positive_ratio = 0.5
|
||||
negative_ratio = 0.5
|
||||
volume = 100
|
||||
virality = 1.5
|
||||
|
||||
# Act
|
||||
score = (positive_ratio - negative_ratio) * volume * virality
|
||||
|
||||
# Assert
|
||||
assert score == 0.0 # Neutral energy
|
||||
|
||||
def test_formula_zero_volume(self):
|
||||
"""Test energy formula with zero volume."""
|
||||
# Arrange
|
||||
positive_ratio = 0.7
|
||||
negative_ratio = 0.2
|
||||
volume = 0
|
||||
virality = 1.5
|
||||
|
||||
# Act
|
||||
score = (positive_ratio - negative_ratio) * volume * virality
|
||||
|
||||
# Assert
|
||||
assert score == 0.0 # No volume means no energy
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
282
backend/tests/test_energy_manual.py
Normal file
282
backend/tests/test_energy_manual.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Manual test script for energy calculator.
|
||||
|
||||
This script runs basic tests to verify the energy calculator implementation.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from app.ml.energy_calculator import (
|
||||
calculate_energy_score,
|
||||
apply_source_weights,
|
||||
adjust_weights_for_degraded_mode,
|
||||
apply_temporal_weighting,
|
||||
normalize_score,
|
||||
calculate_confidence
|
||||
)
|
||||
|
||||
|
||||
def test_apply_source_weights():
|
||||
"""Test source weight application."""
|
||||
print("\n=== Test: Apply Source Weights ===")
|
||||
|
||||
# Test with all sources
|
||||
weighted_score = apply_source_weights(
|
||||
twitter_score=50.0,
|
||||
reddit_score=40.0,
|
||||
rss_score=30.0,
|
||||
available_sources=['twitter', 'reddit', 'rss']
|
||||
)
|
||||
expected = (50.0 * 0.60) + (40.0 * 0.25) + (30.0 * 0.15)
|
||||
assert weighted_score == expected, f"Expected {expected}, got {weighted_score}"
|
||||
print(f"✓ All sources: {weighted_score} (expected: {expected})")
|
||||
|
||||
# Test with Twitter only
|
||||
weighted_score = apply_source_weights(
|
||||
twitter_score=50.0,
|
||||
reddit_score=0.0,
|
||||
rss_score=0.0,
|
||||
available_sources=['twitter']
|
||||
)
|
||||
assert weighted_score == 50.0, f"Expected 50.0, got {weighted_score}"
|
||||
print(f"✓ Twitter only: {weighted_score}")
|
||||
|
||||
|
||||
def test_adjust_weights_for_degraded_mode():
|
||||
"""Test degraded mode weight adjustment."""
|
||||
print("\n=== Test: Adjust Weights for Degraded Mode ===")
|
||||
|
||||
original_weights = {
|
||||
'twitter': 0.60,
|
||||
'reddit': 0.25,
|
||||
'rss': 0.15
|
||||
}
|
||||
|
||||
# Test with Twitter and Reddit only
|
||||
adjusted = adjust_weights_for_degraded_mode(
|
||||
original_weights=original_weights,
|
||||
available_sources=['twitter', 'reddit']
|
||||
)
|
||||
total_weight = sum(adjusted.values())
|
||||
assert abs(total_weight - 1.0) < 0.001, f"Total weight should be 1.0, got {total_weight}"
|
||||
print(f"✓ Twitter+Reddit weights: {adjusted} (total: {total_weight})")
|
||||
|
||||
# Test with only Twitter
|
||||
adjusted = adjust_weights_for_degraded_mode(
|
||||
original_weights=original_weights,
|
||||
available_sources=['twitter']
|
||||
)
|
||||
total_weight = sum(adjusted.values())
|
||||
assert adjusted['twitter'] == 1.0, f"Twitter weight should be 1.0, got {adjusted['twitter']}"
|
||||
print(f"✓ Twitter only: {adjusted}")
|
||||
|
||||
|
||||
def test_normalize_score():
|
||||
"""Test score normalization."""
|
||||
print("\n=== Test: Normalize Score ===")
|
||||
|
||||
# Test negative score
|
||||
normalized = normalize_score(-10.0)
|
||||
assert normalized == 0.0, f"Expected 0.0, got {normalized}"
|
||||
print(f"✓ Negative score: -10.0 → {normalized}")
|
||||
|
||||
# Test score above 100
|
||||
normalized = normalize_score(150.0)
|
||||
assert normalized == 100.0, f"Expected 100.0, got {normalized}"
|
||||
print(f"✓ Score above 100: 150.0 → {normalized}")
|
||||
|
||||
# Test score in range
|
||||
normalized = normalize_score(50.0)
|
||||
assert normalized == 50.0, f"Expected 50.0, got {normalized}"
|
||||
print(f"✓ Score in range: 50.0 → {normalized}")
|
||||
|
||||
|
||||
def test_calculate_confidence():
|
||||
"""Test confidence calculation."""
|
||||
print("\n=== Test: Calculate Confidence ===")
|
||||
|
||||
# Test with all sources
|
||||
confidence = calculate_confidence(
|
||||
available_sources=['twitter', 'reddit', 'rss'],
|
||||
total_weight=1.0
|
||||
)
|
||||
assert confidence > 0.6, f"Confidence should be > 0.6, got {confidence}"
|
||||
print(f"✓ All sources: {confidence}")
|
||||
|
||||
# Test with single source (Twitter)
|
||||
confidence = calculate_confidence(
|
||||
available_sources=['twitter'],
|
||||
total_weight=0.6
|
||||
)
|
||||
assert confidence == 0.6, f"Expected 0.6, got {confidence}"
|
||||
print(f"✓ Twitter only: {confidence}")
|
||||
|
||||
# Test with no sources
|
||||
confidence = calculate_confidence(
|
||||
available_sources=[],
|
||||
total_weight=0.0
|
||||
)
|
||||
assert confidence == 0.0, f"Expected 0.0, got {confidence}"
|
||||
print(f"✓ No sources: {confidence}")
|
||||
|
||||
|
||||
def test_apply_temporal_weighting():
|
||||
"""Test temporal weighting."""
|
||||
print("\n=== Test: Apply Temporal Weighting ===")
|
||||
|
||||
base_score = 50.0
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Test with recent tweets (within 1 hour)
|
||||
recent_tweets = [
|
||||
{
|
||||
'compound': 0.5,
|
||||
'created_at': now - timedelta(minutes=30)
|
||||
},
|
||||
{
|
||||
'compound': 0.6,
|
||||
'created_at': now - timedelta(minutes=15)
|
||||
}
|
||||
]
|
||||
weighted_score = apply_temporal_weighting(
|
||||
base_score=base_score,
|
||||
tweets_with_timestamps=recent_tweets
|
||||
)
|
||||
assert 0 <= weighted_score <= 100, f"Score should be between 0 and 100, got {weighted_score}"
|
||||
print(f"✓ Recent tweets: {base_score} → {weighted_score}")
|
||||
|
||||
# Test with old tweets (24+ hours)
|
||||
old_tweets = [
|
||||
{
|
||||
'compound': 0.5,
|
||||
'created_at': now - timedelta(hours=30)
|
||||
},
|
||||
{
|
||||
'compound': 0.6,
|
||||
'created_at': now - timedelta(hours=25)
|
||||
}
|
||||
]
|
||||
weighted_score = apply_temporal_weighting(
|
||||
base_score=base_score,
|
||||
tweets_with_timestamps=old_tweets
|
||||
)
|
||||
assert 0 <= weighted_score <= 100, f"Score should be between 0 and 100, got {weighted_score}"
|
||||
print(f"✓ Old tweets: {base_score} → {weighted_score}")
|
||||
|
||||
|
||||
def test_calculate_energy_score_all_sources():
|
||||
"""Test energy score calculation with all sources."""
|
||||
print("\n=== Test: Calculate Energy Score (All Sources) ===")
|
||||
|
||||
twitter_sentiments = [
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'},
|
||||
{'compound': 0.7, 'positive': 0.7, 'negative': 0.1, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
|
||||
reddit_sentiments = [
|
||||
{'compound': 0.3, 'positive': 0.4, 'negative': 0.3, 'neutral': 0.3, 'sentiment': 'positive'}
|
||||
]
|
||||
|
||||
rss_sentiments = [
|
||||
{'compound': 0.4, 'positive': 0.5, 'negative': 0.2, 'neutral': 0.3, 'sentiment': 'positive'}
|
||||
]
|
||||
|
||||
result = calculate_energy_score(
|
||||
match_id=1,
|
||||
team_id=1,
|
||||
twitter_sentiments=twitter_sentiments,
|
||||
reddit_sentiments=reddit_sentiments,
|
||||
rss_sentiments=rss_sentiments
|
||||
)
|
||||
|
||||
assert 'score' in result, "Result should contain 'score'"
|
||||
assert 'confidence' in result, "Result should contain 'confidence'"
|
||||
assert 'sources_used' in result, "Result should contain 'sources_used'"
|
||||
assert 0 <= result['score'] <= 100, f"Score should be between 0 and 100, got {result['score']}"
|
||||
assert len(result['sources_used']) == 3, f"Should use 3 sources, got {len(result['sources_used'])}"
|
||||
assert result['confidence'] > 0.6, f"Confidence should be > 0.6, got {result['confidence']}"
|
||||
|
||||
print(f"✓ Score: {result['score']}")
|
||||
print(f"✓ Confidence: {result['confidence']}")
|
||||
print(f"✓ Sources used: {result['sources_used']}")
|
||||
|
||||
|
||||
def test_calculate_energy_score_degraded_mode():
|
||||
"""Test energy score calculation in degraded mode."""
|
||||
print("\n=== Test: Calculate Energy Score (Degraded Mode) ===")
|
||||
|
||||
twitter_sentiments = [
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
|
||||
result = calculate_energy_score(
|
||||
match_id=1,
|
||||
team_id=1,
|
||||
twitter_sentiments=twitter_sentiments,
|
||||
reddit_sentiments=[],
|
||||
rss_sentiments=[]
|
||||
)
|
||||
|
||||
assert result['score'] >= 0, f"Score should be >= 0, got {result['score']}"
|
||||
assert result['confidence'] < 0.6, f"Confidence should be < 0.6 in degraded mode, got {result['confidence']}"
|
||||
assert len(result['sources_used']) == 1, f"Should use 1 source, got {len(result['sources_used'])}"
|
||||
assert 'twitter' in result['sources_used'], "Should use Twitter"
|
||||
|
||||
print(f"✓ Score: {result['score']}")
|
||||
print(f"✓ Confidence: {result['confidence']} (lower in degraded mode)")
|
||||
print(f"✓ Sources used: {result['sources_used']}")
|
||||
|
||||
|
||||
def test_calculate_energy_score_no_data():
|
||||
"""Test energy score calculation with no data."""
|
||||
print("\n=== Test: Calculate Energy Score (No Data) ===")
|
||||
|
||||
result = calculate_energy_score(
|
||||
match_id=1,
|
||||
team_id=1,
|
||||
twitter_sentiments=[],
|
||||
reddit_sentiments=[],
|
||||
rss_sentiments=[]
|
||||
)
|
||||
|
||||
assert result['score'] == 0.0, f"Score should be 0.0, got {result['score']}"
|
||||
assert result['confidence'] == 0.0, f"Confidence should be 0.0, got {result['confidence']}"
|
||||
assert len(result['sources_used']) == 0, f"Should use 0 sources, got {len(result['sources_used'])}"
|
||||
|
||||
print(f"✓ Score: {result['score']}")
|
||||
print(f"✓ Confidence: {result['confidence']}")
|
||||
print(f"✓ Sources used: {result['sources_used']}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all manual tests."""
|
||||
print("=" * 60)
|
||||
print("MANUAL TESTS FOR ENERGY CALCULATOR")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
test_apply_source_weights()
|
||||
test_adjust_weights_for_degraded_mode()
|
||||
test_normalize_score()
|
||||
test_calculate_confidence()
|
||||
test_apply_temporal_weighting()
|
||||
test_calculate_energy_score_all_sources()
|
||||
test_calculate_energy_score_degraded_mode()
|
||||
test_calculate_energy_score_no_data()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ ALL TESTS PASSED!")
|
||||
print("=" * 60)
|
||||
return 0
|
||||
|
||||
except AssertionError as e:
|
||||
print(f"\n❌ TEST FAILED: {e}")
|
||||
return 1
|
||||
except Exception as e:
|
||||
print(f"\n❌ UNEXPECTED ERROR: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
exit(main())
|
||||
394
backend/tests/test_energy_service.py
Normal file
394
backend/tests/test_energy_service.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""
|
||||
Tests for Energy Service.
|
||||
|
||||
This module tests the energy service business logic including:
|
||||
- Energy score calculation and storage
|
||||
- Retrieval of energy scores
|
||||
- Updating and deleting energy scores
|
||||
- Filtering and querying
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from app.services.energy_service import (
|
||||
calculate_and_store_energy_score,
|
||||
get_energy_score,
|
||||
get_energy_scores_by_match,
|
||||
get_energy_scores_by_team,
|
||||
get_energy_score_by_match_and_team,
|
||||
update_energy_score,
|
||||
delete_energy_score,
|
||||
list_energy_scores
|
||||
)
|
||||
from app.schemas.energy_score import EnergyScoreCalculationRequest, EnergyScoreUpdate
|
||||
|
||||
|
||||
class TestCalculateAndStoreEnergyScore:
|
||||
"""Test energy score calculation and storage."""
|
||||
|
||||
def test_calculate_and_store_with_all_sources(self, db_session):
|
||||
"""Test calculation and storage with all sources available."""
|
||||
# Arrange
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=1,
|
||||
team_id=1,
|
||||
twitter_sentiments=[
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
],
|
||||
reddit_sentiments=[
|
||||
{'compound': 0.3, 'positive': 0.4, 'negative': 0.3, 'neutral': 0.3, 'sentiment': 'positive'}
|
||||
],
|
||||
rss_sentiments=[
|
||||
{'compound': 0.4, 'positive': 0.5, 'negative': 0.2, 'neutral': 0.3, 'sentiment': 'positive'}
|
||||
]
|
||||
)
|
||||
|
||||
# Act
|
||||
energy_score = calculate_and_store_energy_score(db_session, request)
|
||||
|
||||
# Assert
|
||||
assert energy_score is not None
|
||||
assert energy_score.match_id == 1
|
||||
assert energy_score.team_id == 1
|
||||
assert 0 <= energy_score.score <= 100
|
||||
assert energy_score.confidence > 0
|
||||
assert len(energy_score.sources_used) == 3
|
||||
assert energy_score.twitter_score is not None
|
||||
assert energy_score.reddit_score is not None
|
||||
assert energy_score.rss_score is not None
|
||||
|
||||
def test_calculate_and_store_with_twitter_only(self, db_session):
|
||||
"""Test calculation and storage with only Twitter source."""
|
||||
# Arrange
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=1,
|
||||
team_id=1,
|
||||
twitter_sentiments=[
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
)
|
||||
|
||||
# Act
|
||||
energy_score = calculate_and_store_energy_score(db_session, request)
|
||||
|
||||
# Assert
|
||||
assert energy_score is not None
|
||||
assert energy_score.match_id == 1
|
||||
assert energy_score.team_id == 1
|
||||
assert len(energy_score.sources_used) == 1
|
||||
assert 'twitter' in energy_score.sources_used
|
||||
assert energy_score.twitter_score is not None
|
||||
assert energy_score.reddit_score is None
|
||||
assert energy_score.rss_score is None
|
||||
# Lower confidence in degraded mode
|
||||
assert energy_score.confidence < 0.7
|
||||
|
||||
def test_calculate_and_store_no_sentiment_data(self, db_session):
|
||||
"""Test calculation with no sentiment data."""
|
||||
# Arrange
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=1,
|
||||
team_id=1,
|
||||
twitter_sentiments=[],
|
||||
reddit_sentiments=[],
|
||||
rss_sentiments=[]
|
||||
)
|
||||
|
||||
# Act
|
||||
energy_score = calculate_and_store_energy_score(db_session, request)
|
||||
|
||||
# Assert
|
||||
assert energy_score is not None
|
||||
assert energy_score.score == 0.0
|
||||
assert energy_score.confidence == 0.0
|
||||
assert len(energy_score.sources_used) == 0
|
||||
|
||||
|
||||
class TestGetEnergyScore:
|
||||
"""Test retrieval of energy scores."""
|
||||
|
||||
def test_get_energy_score_by_id(self, db_session):
|
||||
"""Test retrieving energy score by ID."""
|
||||
# Arrange
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=1,
|
||||
team_id=1,
|
||||
twitter_sentiments=[
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
)
|
||||
created = calculate_and_store_energy_score(db_session, request)
|
||||
|
||||
# Act
|
||||
retrieved = get_energy_score(db_session, created.id)
|
||||
|
||||
# Assert
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == created.id
|
||||
assert retrieved.match_id == created.match_id
|
||||
assert retrieved.team_id == created.team_id
|
||||
assert retrieved.score == created.score
|
||||
|
||||
def test_get_energy_score_not_found(self, db_session):
|
||||
"""Test retrieving non-existent energy score."""
|
||||
# Act
|
||||
retrieved = get_energy_score(db_session, 99999)
|
||||
|
||||
# Assert
|
||||
assert retrieved is None
|
||||
|
||||
|
||||
class TestGetEnergyScoresByMatch:
|
||||
"""Test retrieval of energy scores by match."""
|
||||
|
||||
def test_get_energy_scores_by_match_id(self, db_session):
|
||||
"""Test retrieving all energy scores for a specific match."""
|
||||
# Arrange
|
||||
# Create energy scores for different matches
|
||||
for match_id in [1, 1, 2]:
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=match_id,
|
||||
team_id=1,
|
||||
twitter_sentiments=[
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
)
|
||||
calculate_and_store_energy_score(db_session, request)
|
||||
|
||||
# Act
|
||||
scores_for_match_1 = get_energy_scores_by_match(db_session, 1)
|
||||
scores_for_match_2 = get_energy_scores_by_match(db_session, 2)
|
||||
|
||||
# Assert
|
||||
assert len(scores_for_match_1) == 2
|
||||
assert all(score.match_id == 1 for score in scores_for_match_1)
|
||||
assert len(scores_for_match_2) == 1
|
||||
assert scores_for_match_2[0].match_id == 2
|
||||
|
||||
|
||||
class TestGetEnergyScoresByTeam:
|
||||
"""Test retrieval of energy scores by team."""
|
||||
|
||||
def test_get_energy_scores_by_team_id(self, db_session):
|
||||
"""Test retrieving all energy scores for a specific team."""
|
||||
# Arrange
|
||||
# Create energy scores for different teams
|
||||
for team_id in [1, 1, 2]:
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=1,
|
||||
team_id=team_id,
|
||||
twitter_sentiments=[
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
)
|
||||
calculate_and_store_energy_score(db_session, request)
|
||||
|
||||
# Act
|
||||
scores_for_team_1 = get_energy_scores_by_team(db_session, 1)
|
||||
scores_for_team_2 = get_energy_scores_by_team(db_session, 2)
|
||||
|
||||
# Assert
|
||||
assert len(scores_for_team_1) == 2
|
||||
assert all(score.team_id == 1 for score in scores_for_team_1)
|
||||
assert len(scores_for_team_2) == 1
|
||||
assert scores_for_team_2[0].team_id == 2
|
||||
|
||||
|
||||
class TestGetEnergyScoreByMatchAndTeam:
|
||||
"""Test retrieval of most recent energy score by match and team."""
|
||||
|
||||
def test_get_energy_score_by_match_and_team(self, db_session):
|
||||
"""Test retrieving most recent energy score for match and team."""
|
||||
# Arrange
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=1,
|
||||
team_id=1,
|
||||
twitter_sentiments=[
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
)
|
||||
calculate_and_store_energy_score(db_session, request)
|
||||
|
||||
# Act
|
||||
retrieved = get_energy_score_by_match_and_team(db_session, 1, 1)
|
||||
|
||||
# Assert
|
||||
assert retrieved is not None
|
||||
assert retrieved.match_id == 1
|
||||
assert retrieved.team_id == 1
|
||||
|
||||
def test_get_energy_score_by_match_and_team_not_found(self, db_session):
|
||||
"""Test retrieving non-existent energy score."""
|
||||
# Act
|
||||
retrieved = get_energy_score_by_match_and_team(db_session, 999, 999)
|
||||
|
||||
# Assert
|
||||
assert retrieved is None
|
||||
|
||||
|
||||
class TestUpdateEnergyScore:
|
||||
"""Test updating energy scores."""
|
||||
|
||||
def test_update_energy_score(self, db_session):
|
||||
"""Test updating an existing energy score."""
|
||||
# Arrange
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=1,
|
||||
team_id=1,
|
||||
twitter_sentiments=[
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
)
|
||||
created = calculate_and_store_energy_score(db_session, request)
|
||||
|
||||
# Act
|
||||
update = EnergyScoreUpdate(
|
||||
score=75.0,
|
||||
confidence=0.9
|
||||
)
|
||||
updated = update_energy_score(db_session, created.id, update)
|
||||
|
||||
# Assert
|
||||
assert updated is not None
|
||||
assert updated.id == created.id
|
||||
assert updated.score == 75.0
|
||||
assert updated.confidence == 0.9
|
||||
assert updated.updated_at > created.updated_at
|
||||
|
||||
def test_update_energy_score_not_found(self, db_session):
|
||||
"""Test updating non-existent energy score."""
|
||||
# Arrange
|
||||
update = EnergyScoreUpdate(score=75.0, confidence=0.9)
|
||||
|
||||
# Act
|
||||
updated = update_energy_score(db_session, 99999, update)
|
||||
|
||||
# Assert
|
||||
assert updated is None
|
||||
|
||||
|
||||
class TestDeleteEnergyScore:
|
||||
"""Test deleting energy scores."""
|
||||
|
||||
def test_delete_energy_score(self, db_session):
|
||||
"""Test deleting an existing energy score."""
|
||||
# Arrange
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=1,
|
||||
team_id=1,
|
||||
twitter_sentiments=[
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
)
|
||||
created = calculate_and_store_energy_score(db_session, request)
|
||||
|
||||
# Act
|
||||
deleted = delete_energy_score(db_session, created.id)
|
||||
|
||||
# Assert
|
||||
assert deleted is True
|
||||
retrieved = get_energy_score(db_session, created.id)
|
||||
assert retrieved is None
|
||||
|
||||
def test_delete_energy_score_not_found(self, db_session):
|
||||
"""Test deleting non-existent energy score."""
|
||||
# Act
|
||||
deleted = delete_energy_score(db_session, 99999)
|
||||
|
||||
# Assert
|
||||
assert deleted is False
|
||||
|
||||
|
||||
class TestListEnergyScores:
|
||||
"""Test listing and filtering energy scores."""
|
||||
|
||||
def test_list_energy_scores_default(self, db_session):
|
||||
"""Test listing energy scores with default parameters."""
|
||||
# Arrange
|
||||
for i in range(5):
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=i,
|
||||
team_id=1,
|
||||
twitter_sentiments=[
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
)
|
||||
calculate_and_store_energy_score(db_session, request)
|
||||
|
||||
# Act
|
||||
scores = list_energy_scores(db_session)
|
||||
|
||||
# Assert
|
||||
assert len(scores) == 5
|
||||
|
||||
def test_list_energy_scores_with_match_filter(self, db_session):
|
||||
"""Test listing energy scores filtered by match ID."""
|
||||
# Arrange
|
||||
for match_id in [1, 1, 2]:
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=match_id,
|
||||
team_id=1,
|
||||
twitter_sentiments=[
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
)
|
||||
calculate_and_store_energy_score(db_session, request)
|
||||
|
||||
# Act
|
||||
scores = list_energy_scores(db_session, match_id=1)
|
||||
|
||||
# Assert
|
||||
assert len(scores) == 2
|
||||
assert all(score.match_id == 1 for score in scores)
|
||||
|
||||
def test_list_energy_scores_with_score_filter(self, db_session):
|
||||
"""Test listing energy scores filtered by score range."""
|
||||
# Arrange
|
||||
for score in [30.0, 50.0, 70.0, 90.0]:
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=1,
|
||||
team_id=1,
|
||||
twitter_sentiments=[
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
)
|
||||
created = calculate_and_store_energy_score(db_session, request)
|
||||
# Manually update score for testing
|
||||
created.score = score
|
||||
db_session.commit()
|
||||
|
||||
# Act
|
||||
scores = list_energy_scores(db_session, min_score=40.0, max_score=80.0)
|
||||
|
||||
# Assert
|
||||
assert len(scores) == 2
|
||||
assert all(40.0 <= score.score <= 80.0 for score in scores)
|
||||
|
||||
def test_list_energy_scores_with_pagination(self, db_session):
|
||||
"""Test listing energy scores with pagination."""
|
||||
# Arrange
|
||||
for i in range(10):
|
||||
request = EnergyScoreCalculationRequest(
|
||||
match_id=i,
|
||||
team_id=1,
|
||||
twitter_sentiments=[
|
||||
{'compound': 0.5, 'positive': 0.6, 'negative': 0.2, 'neutral': 0.2, 'sentiment': 'positive'}
|
||||
]
|
||||
)
|
||||
calculate_and_store_energy_score(db_session, request)
|
||||
|
||||
# Act
|
||||
page1 = list_energy_scores(db_session, limit=5, offset=0)
|
||||
page2 = list_energy_scores(db_session, limit=5, offset=5)
|
||||
|
||||
# Assert
|
||||
assert len(page1) == 5
|
||||
assert len(page2) == 5
|
||||
# Verify no overlap
|
||||
page1_ids = [score.id for score in page1]
|
||||
page2_ids = [score.id for score in page2]
|
||||
assert set(page1_ids).isdisjoint(set(page2_ids))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
281
backend/tests/test_energy_worker.py
Normal file
281
backend/tests/test_energy_worker.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
Tests for energy calculation worker.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.workers.energy_worker import (
|
||||
EnergyWorker,
|
||||
create_energy_worker
|
||||
)
|
||||
|
||||
|
||||
class TestEnergyWorker:
|
||||
"""Tests for EnergyWorker class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test energy worker initialization."""
|
||||
worker = EnergyWorker()
|
||||
|
||||
# No specific initialization required for energy worker
|
||||
assert worker is not None
|
||||
|
||||
@patch('app.workers.energy_worker.calculate_and_store_energy_score')
|
||||
@patch('app.workers.energy_worker.get_energy_score_by_match_and_team')
|
||||
def test_execute_energy_calculation_task_new(
|
||||
self,
|
||||
mock_get_score,
|
||||
mock_calculate_store
|
||||
):
|
||||
"""Test executing energy calculation for new score."""
|
||||
# Create worker
|
||||
worker = EnergyWorker()
|
||||
|
||||
# Mock database session
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Mock no existing score
|
||||
mock_get_score.return_value = None
|
||||
|
||||
# Mock calculation
|
||||
mock_energy_score = Mock()
|
||||
mock_energy_score.score = 75.5
|
||||
mock_energy_score.confidence = 0.82
|
||||
mock_energy_score.sources_used = ['twitter', 'reddit']
|
||||
mock_energy_score.twitter_score = 0.6
|
||||
mock_energy_score.reddit_score = 0.7
|
||||
mock_energy_score.rss_score = 0.5
|
||||
mock_energy_score.temporal_factor = 1.0
|
||||
mock_calculate_store.return_value = mock_energy_score
|
||||
|
||||
# Execute task
|
||||
task = {
|
||||
'match_id': 123,
|
||||
'team_id': 456,
|
||||
'twitter_sentiments': [
|
||||
{'compound': 0.6, 'sentiment': 'positive'},
|
||||
{'compound': 0.7, 'sentiment': 'positive'}
|
||||
],
|
||||
'reddit_sentiments': [
|
||||
{'compound': 0.7, 'sentiment': 'positive'},
|
||||
{'compound': 0.8, 'sentiment': 'positive'}
|
||||
],
|
||||
'rss_sentiments': [
|
||||
{'compound': 0.5, 'sentiment': 'positive'}
|
||||
],
|
||||
'tweets_with_timestamps': []
|
||||
}
|
||||
|
||||
result = worker.execute_energy_calculation_task(task, mock_db)
|
||||
|
||||
# Verify calculation called
|
||||
mock_calculate_store.assert_called_once()
|
||||
|
||||
# Verify result
|
||||
assert result['energy_score'] == 75.5
|
||||
assert result['confidence'] == 0.82
|
||||
assert result['sources_used'] == ['twitter', 'reddit']
|
||||
assert result['status'] == 'success'
|
||||
assert result['metadata']['match_id'] == 123
|
||||
assert result['metadata']['team_id'] == 456
|
||||
assert result['metadata']['twitter_score'] == 0.6
|
||||
assert result['metadata']['reddit_score'] == 0.7
|
||||
assert result['metadata']['rss_score'] == 0.5
|
||||
assert result['metadata']['temporal_factor'] == 1.0
|
||||
|
||||
@patch('app.workers.energy_worker.get_energy_score_by_match_and_team')
|
||||
def test_execute_energy_calculation_task_existing(
|
||||
self,
|
||||
mock_get_score
|
||||
):
|
||||
"""Test executing energy calculation for existing score."""
|
||||
# Create worker
|
||||
worker = EnergyWorker()
|
||||
|
||||
# Mock database session
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Mock existing score
|
||||
mock_existing_score = Mock()
|
||||
mock_existing_score.score = 72.0
|
||||
mock_existing_score.confidence = 0.78
|
||||
mock_existing_score.sources_used = ['twitter', 'reddit']
|
||||
mock_get_score.return_value = mock_existing_score
|
||||
|
||||
# Execute task
|
||||
task = {
|
||||
'match_id': 123,
|
||||
'team_id': 456,
|
||||
'twitter_sentiments': [],
|
||||
'reddit_sentiments': [],
|
||||
'rss_sentiments': [],
|
||||
'tweets_with_timestamps': []
|
||||
}
|
||||
|
||||
result = worker.execute_energy_calculation_task(task, mock_db)
|
||||
|
||||
# Verify result from existing score
|
||||
assert result['energy_score'] == 72.0
|
||||
assert result['confidence'] == 0.78
|
||||
assert result['sources_used'] == ['twitter', 'reddit']
|
||||
assert result['status'] == 'success'
|
||||
assert result['metadata']['updated_existing'] is True
|
||||
|
||||
@patch('app.workers.energy_worker.calculate_and_store_energy_score')
|
||||
@patch('app.workers.energy_worker.get_energy_score_by_match_and_team')
|
||||
def test_execute_energy_calculation_task_error_handling(
|
||||
self,
|
||||
mock_get_score,
|
||||
mock_calculate_store
|
||||
):
|
||||
"""Test error handling in energy calculation."""
|
||||
worker = EnergyWorker()
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Mock no existing score
|
||||
mock_get_score.return_value = None
|
||||
|
||||
# Mock calculation error
|
||||
mock_calculate_store.side_effect = Exception("Calculation error")
|
||||
|
||||
# Execute task
|
||||
task = {
|
||||
'match_id': 123,
|
||||
'team_id': 456,
|
||||
'twitter_sentiments': [],
|
||||
'reddit_sentiments': [],
|
||||
'rss_sentiments': [],
|
||||
'tweets_with_timestamps': []
|
||||
}
|
||||
|
||||
result = worker.execute_energy_calculation_task(task, mock_db)
|
||||
|
||||
# Verify error handling
|
||||
assert result['energy_score'] == 0.0
|
||||
assert result['confidence'] == 0.0
|
||||
assert result['sources_used'] == []
|
||||
assert result['status'] == 'error'
|
||||
assert 'error' in result
|
||||
|
||||
@patch('app.workers.energy_worker.calculate_energy_score')
|
||||
def test_calculate_mock_energy(self, mock_calculate):
|
||||
"""Test mock energy calculation (for testing)."""
|
||||
worker = EnergyWorker()
|
||||
|
||||
# Mock calculation
|
||||
mock_calculate.return_value = {
|
||||
'score': 75.5,
|
||||
'confidence': 0.82,
|
||||
'sources_used': ['twitter', 'reddit']
|
||||
}
|
||||
|
||||
# Calculate mock energy
|
||||
twitter_sentiments = [
|
||||
{'compound': 0.6, 'sentiment': 'positive'}
|
||||
]
|
||||
reddit_sentiments = [
|
||||
{'compound': 0.7, 'sentiment': 'positive'}
|
||||
]
|
||||
|
||||
result = worker.calculate_mock_energy(
|
||||
twitter_sentiments=twitter_sentiments,
|
||||
reddit_sentiments=reddit_sentiments
|
||||
)
|
||||
|
||||
# Verify calculation called
|
||||
mock_calculate.assert_called_once_with(
|
||||
match_id=0,
|
||||
team_id=0,
|
||||
twitter_sentiments=twitter_sentiments,
|
||||
reddit_sentiments=reddit_sentiments,
|
||||
rss_sentiments=[],
|
||||
tweets_with_timestamps=[]
|
||||
)
|
||||
|
||||
# Verify result
|
||||
assert result['energy_score'] == 75.5
|
||||
assert result['confidence'] == 0.82
|
||||
assert result['sources_used'] == ['twitter', 'reddit']
|
||||
assert result['status'] == 'success'
|
||||
|
||||
@patch('app.workers.energy_worker.calculate_energy_score')
|
||||
def test_calculate_mock_energy_error(self, mock_calculate):
|
||||
"""Test error handling in mock energy calculation."""
|
||||
worker = EnergyWorker()
|
||||
|
||||
# Mock calculation error
|
||||
mock_calculate.side_effect = Exception("Calculation error")
|
||||
|
||||
# Calculate mock energy
|
||||
result = worker.calculate_mock_energy(
|
||||
twitter_sentiments=[],
|
||||
reddit_sentiments=[]
|
||||
)
|
||||
|
||||
# Verify error handling
|
||||
assert result['energy_score'] == 0.0
|
||||
assert result['confidence'] == 0.0
|
||||
assert result['sources_used'] == []
|
||||
assert result['status'] == 'error'
|
||||
assert 'error' in result
|
||||
|
||||
@patch('app.workers.energy_worker.calculate_energy_score')
|
||||
def test_calculate_mock_energy_with_all_sources(self, mock_calculate):
|
||||
"""Test mock energy calculation with all sources."""
|
||||
worker = EnergyWorker()
|
||||
|
||||
# Mock calculation
|
||||
mock_calculate.return_value = {
|
||||
'score': 80.0,
|
||||
'confidence': 0.85,
|
||||
'sources_used': ['twitter', 'reddit', 'rss']
|
||||
}
|
||||
|
||||
# Calculate mock energy with all sources
|
||||
twitter_sentiments = [
|
||||
{'compound': 0.6, 'sentiment': 'positive'}
|
||||
]
|
||||
reddit_sentiments = [
|
||||
{'compound': 0.7, 'sentiment': 'positive'}
|
||||
]
|
||||
rss_sentiments = [
|
||||
{'compound': 0.5, 'sentiment': 'positive'}
|
||||
]
|
||||
tweets_with_timestamps = [
|
||||
{'timestamp': '2026-01-17T10:00:00', 'text': 'Great match!'}
|
||||
]
|
||||
|
||||
result = worker.calculate_mock_energy(
|
||||
twitter_sentiments=twitter_sentiments,
|
||||
reddit_sentiments=reddit_sentiments,
|
||||
rss_sentiments=rss_sentiments,
|
||||
tweets_with_timestamps=tweets_with_timestamps
|
||||
)
|
||||
|
||||
# Verify calculation called with all sources
|
||||
mock_calculate.assert_called_once_with(
|
||||
match_id=0,
|
||||
team_id=0,
|
||||
twitter_sentiments=twitter_sentiments,
|
||||
reddit_sentiments=reddit_sentiments,
|
||||
rss_sentiments=rss_sentiments,
|
||||
tweets_with_timestamps=tweets_with_timestamps
|
||||
)
|
||||
|
||||
# Verify result
|
||||
assert result['energy_score'] == 80.0
|
||||
assert result['confidence'] == 0.85
|
||||
assert result['sources_used'] == ['twitter', 'reddit', 'rss']
|
||||
assert result['status'] == 'success'
|
||||
|
||||
|
||||
class TestCreateEnergyWorker:
|
||||
"""Tests for create_energy_worker factory function."""
|
||||
|
||||
def test_create_energy_worker(self):
|
||||
"""Test creating an energy worker."""
|
||||
worker = create_energy_worker()
|
||||
|
||||
assert isinstance(worker, EnergyWorker)
|
||||
152
backend/tests/test_match_model.py
Normal file
152
backend/tests/test_match_model.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Unit tests for Match model.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.match import Match
|
||||
from app.models.prediction import Prediction
|
||||
from app.database import Base, engine, SessionLocal
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
"""Create a fresh database session for each test."""
|
||||
# Create tables
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# Create session
|
||||
session = SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
session.rollback()
|
||||
finally:
|
||||
session.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
class TestMatchModel:
|
||||
"""Test Match SQLAlchemy model."""
|
||||
|
||||
def test_match_creation(self, db_session: Session):
|
||||
"""Test creating a match in database."""
|
||||
match = Match(
|
||||
home_team="PSG",
|
||||
away_team="Olympique de Marseille",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Ligue 1",
|
||||
status="scheduled"
|
||||
)
|
||||
|
||||
db_session.add(match)
|
||||
db_session.commit()
|
||||
db_session.refresh(match)
|
||||
|
||||
assert match.id is not None
|
||||
assert match.home_team == "PSG"
|
||||
assert match.away_team == "Olympique de Marseille"
|
||||
assert match.league == "Ligue 1"
|
||||
assert match.status == "scheduled"
|
||||
|
||||
def test_match_required_fields(self, db_session: Session):
|
||||
"""Test that all required fields must be provided."""
|
||||
# Missing home_team
|
||||
match = Match(
|
||||
away_team="Olympique de Marseille",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Ligue 1",
|
||||
status="scheduled"
|
||||
)
|
||||
|
||||
db_session.add(match)
|
||||
|
||||
with pytest.raises(Exception): # IntegrityError expected
|
||||
db_session.commit()
|
||||
|
||||
def test_match_to_dict(self, db_session: Session):
|
||||
"""Test converting match to dictionary."""
|
||||
match = Match(
|
||||
home_team="Barcelona",
|
||||
away_team="Real Madrid",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="La Liga",
|
||||
status="in_progress"
|
||||
)
|
||||
|
||||
db_session.add(match)
|
||||
db_session.commit()
|
||||
db_session.refresh(match)
|
||||
|
||||
match_dict = match.to_dict()
|
||||
|
||||
assert match_dict['home_team'] == "Barcelona"
|
||||
assert match_dict['away_team'] == "Real Madrid"
|
||||
assert match_dict['league'] == "La Liga"
|
||||
assert match_dict['status'] == "in_progress"
|
||||
assert 'id' in match_dict
|
||||
assert 'date' in match_dict
|
||||
|
||||
def test_match_repr(self, db_session: Session):
|
||||
"""Test match __repr__ method."""
|
||||
match = Match(
|
||||
home_team="Manchester City",
|
||||
away_team="Liverpool",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Premier League",
|
||||
status="scheduled"
|
||||
)
|
||||
|
||||
db_session.add(match)
|
||||
db_session.commit()
|
||||
db_session.refresh(match)
|
||||
|
||||
repr_str = repr(match)
|
||||
|
||||
assert "Match" in repr_str
|
||||
assert "id=" in repr_str
|
||||
assert "Manchester City" in repr_str
|
||||
assert "Liverpool" in repr_str
|
||||
|
||||
def test_match_relationships(self, db_session: Session):
|
||||
"""Test match relationships with predictions."""
|
||||
match = Match(
|
||||
home_team="Juventus",
|
||||
away_team="Inter Milan",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Serie A",
|
||||
status="scheduled"
|
||||
)
|
||||
|
||||
db_session.add(match)
|
||||
db_session.commit()
|
||||
db_session.refresh(match)
|
||||
|
||||
# Create predictions
|
||||
prediction1 = Prediction(
|
||||
match_id=match.id,
|
||||
energy_score="high",
|
||||
confidence="80%",
|
||||
predicted_winner="Juventus",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
prediction2 = Prediction(
|
||||
match_id=match.id,
|
||||
energy_score="medium",
|
||||
confidence="60%",
|
||||
predicted_winner="Inter Milan",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db_session.add(prediction1)
|
||||
db_session.add(prediction2)
|
||||
db_session.commit()
|
||||
|
||||
# Refresh match to load relationships
|
||||
db_session.refresh(match)
|
||||
|
||||
assert len(match.predictions) == 2
|
||||
assert match.predictions[0].match_id == match.id
|
||||
assert match.predictions[1].match_id == match.id
|
||||
215
backend/tests/test_match_schema.py
Normal file
215
backend/tests/test_match_schema.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Unit tests for Match Pydantic schemas.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.match import (
|
||||
MatchBase,
|
||||
MatchCreate,
|
||||
MatchUpdate,
|
||||
MatchResponse,
|
||||
MatchListResponse,
|
||||
MatchStatsResponse
|
||||
)
|
||||
|
||||
|
||||
class TestMatchBase:
|
||||
"""Test MatchBase schema."""
|
||||
|
||||
def test_match_base_valid(self):
|
||||
"""Test creating a valid MatchBase."""
|
||||
match_data = {
|
||||
"home_team": "PSG",
|
||||
"away_team": "Olympique de Marseille",
|
||||
"date": datetime.now(timezone.utc),
|
||||
"league": "Ligue 1",
|
||||
"status": "scheduled"
|
||||
}
|
||||
|
||||
match = MatchBase(**match_data)
|
||||
|
||||
assert match.home_team == "PSG"
|
||||
assert match.away_team == "Olympique de Marseille"
|
||||
assert match.league == "Ligue 1"
|
||||
assert match.status == "scheduled"
|
||||
|
||||
def test_match_base_home_team_too_long(self):
|
||||
"""Test that home_team exceeds max length."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
MatchBase(
|
||||
home_team="A" * 256, # Too long
|
||||
away_team="Olympique de Marseille",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Ligue 1",
|
||||
status="scheduled"
|
||||
)
|
||||
|
||||
assert "at most 255 characters" in str(exc_info.value).lower()
|
||||
|
||||
def test_match_base_away_team_empty(self):
|
||||
"""Test that away_team cannot be empty."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
MatchBase(
|
||||
home_team="PSG",
|
||||
away_team="", # Empty
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Ligue 1",
|
||||
status="scheduled"
|
||||
)
|
||||
|
||||
assert "at least 1 character" in str(exc_info.value).lower()
|
||||
|
||||
def test_match_base_status_too_long(self):
|
||||
"""Test that status exceeds max length."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
MatchBase(
|
||||
home_team="PSG",
|
||||
away_team="Olympique de Marseille",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Ligue 1",
|
||||
status="A" * 51 # Too long
|
||||
)
|
||||
|
||||
assert "at most 50 characters" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestMatchCreate:
|
||||
"""Test MatchCreate schema."""
|
||||
|
||||
def test_match_create_valid(self):
|
||||
"""Test creating a valid MatchCreate."""
|
||||
match_data = {
|
||||
"home_team": "Barcelona",
|
||||
"away_team": "Real Madrid",
|
||||
"date": datetime.now(timezone.utc),
|
||||
"league": "La Liga",
|
||||
"status": "in_progress"
|
||||
}
|
||||
|
||||
match = MatchCreate(**match_data)
|
||||
|
||||
assert match.home_team == "Barcelona"
|
||||
assert match.away_team == "Real Madrid"
|
||||
assert match.status == "in_progress"
|
||||
|
||||
|
||||
class TestMatchUpdate:
|
||||
"""Test MatchUpdate schema."""
|
||||
|
||||
def test_match_update_partial(self):
|
||||
"""Test updating only some fields."""
|
||||
update_data = {
|
||||
"status": "completed"
|
||||
}
|
||||
|
||||
match_update = MatchUpdate(**update_data)
|
||||
|
||||
assert match_update.status == "completed"
|
||||
assert match_update.home_team is None
|
||||
assert match_update.away_team is None
|
||||
|
||||
def test_match_update_all_fields(self):
|
||||
"""Test updating all fields."""
|
||||
update_data = {
|
||||
"home_team": "Manchester City",
|
||||
"away_team": "Liverpool",
|
||||
"date": datetime.now(timezone.utc),
|
||||
"league": "Premier League",
|
||||
"status": "in_progress"
|
||||
}
|
||||
|
||||
match_update = MatchUpdate(**update_data)
|
||||
|
||||
assert match_update.home_team == "Manchester City"
|
||||
assert match_update.away_team == "Liverpool"
|
||||
assert match_update.league == "Premier League"
|
||||
assert match_update.status == "in_progress"
|
||||
|
||||
def test_match_update_empty(self):
|
||||
"""Test that MatchUpdate can be empty."""
|
||||
match_update = MatchUpdate()
|
||||
|
||||
assert match_update.home_team is None
|
||||
assert match_update.away_team is None
|
||||
assert match_update.date is None
|
||||
assert match_update.league is None
|
||||
assert match_update.status is None
|
||||
|
||||
|
||||
class TestMatchResponse:
|
||||
"""Test MatchResponse schema."""
|
||||
|
||||
def test_match_response_from_dict(self):
|
||||
"""Test creating MatchResponse from dictionary."""
|
||||
match_dict = {
|
||||
"id": 1,
|
||||
"home_team": "Juventus",
|
||||
"away_team": "Inter Milan",
|
||||
"date": datetime.now(timezone.utc),
|
||||
"league": "Serie A",
|
||||
"status": "scheduled"
|
||||
}
|
||||
|
||||
match = MatchResponse(**match_dict)
|
||||
|
||||
assert match.id == 1
|
||||
assert match.home_team == "Juventus"
|
||||
assert match.away_team == "Inter Milan"
|
||||
assert match.league == "Serie A"
|
||||
assert match.status == "scheduled"
|
||||
|
||||
|
||||
class TestMatchListResponse:
|
||||
"""Test MatchListResponse schema."""
|
||||
|
||||
def test_match_list_response(self):
|
||||
"""Test creating a MatchListResponse."""
|
||||
matches_data = [
|
||||
{
|
||||
"id": 1,
|
||||
"home_team": "PSG",
|
||||
"away_team": "Olympique de Marseille",
|
||||
"date": datetime.now(timezone.utc),
|
||||
"league": "Ligue 1",
|
||||
"status": "scheduled"
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"home_team": "Barcelona",
|
||||
"away_team": "Real Madrid",
|
||||
"date": datetime.now(timezone.utc),
|
||||
"league": "La Liga",
|
||||
"status": "scheduled"
|
||||
}
|
||||
]
|
||||
|
||||
response = MatchListResponse(data=matches_data, count=2, meta={"page": 1})
|
||||
|
||||
assert response.count == 2
|
||||
assert len(response.data) == 2
|
||||
assert response.meta["page"] == 1
|
||||
|
||||
|
||||
class TestMatchStatsResponse:
|
||||
"""Test MatchStatsResponse schema."""
|
||||
|
||||
def test_match_stats_response(self):
|
||||
"""Test creating a MatchStatsResponse."""
|
||||
stats = {
|
||||
"total_matches": 10,
|
||||
"matches_by_league": {"Ligue 1": 5, "La Liga": 5},
|
||||
"matches_by_status": {"scheduled": 5, "completed": 5},
|
||||
"upcoming_matches": 5,
|
||||
"completed_matches": 5
|
||||
}
|
||||
|
||||
match_stats = MatchStatsResponse(**stats)
|
||||
|
||||
assert match_stats.total_matches == 10
|
||||
assert match_stats.matches_by_league["Ligue 1"] == 5
|
||||
assert match_stats.matches_by_status["scheduled"] == 5
|
||||
assert match_stats.upcoming_matches == 5
|
||||
assert match_stats.completed_matches == 5
|
||||
538
backend/tests/test_prediction_api.py
Normal file
538
backend/tests/test_prediction_api.py
Normal file
@@ -0,0 +1,538 @@
|
||||
"""
|
||||
Unit tests for Prediction API endpoints.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import Base, engine, SessionLocal
|
||||
from app.main import app
|
||||
from app.models.match import Match
|
||||
from app.models.prediction import Prediction
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
"""Create a fresh database session for each test."""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session = SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
session.rollback()
|
||||
finally:
|
||||
session.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_match(db_session: Session) -> Match:
|
||||
"""Create a sample match for testing."""
|
||||
match = Match(
|
||||
home_team="PSG",
|
||||
away_team="Olympique de Marseille",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Ligue 1",
|
||||
status="scheduled"
|
||||
)
|
||||
db_session.add(match)
|
||||
db_session.commit()
|
||||
db_session.refresh(match)
|
||||
return match
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_prediction(db_session: Session, sample_match: Match) -> Prediction:
|
||||
"""Create a sample prediction for testing."""
|
||||
prediction = Prediction(
|
||||
match_id=sample_match.id,
|
||||
energy_score="high",
|
||||
confidence="75.0%",
|
||||
predicted_winner="PSG",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db_session.add(prediction)
|
||||
db_session.commit()
|
||||
db_session.refresh(prediction)
|
||||
return prediction
|
||||
|
||||
|
||||
class TestCreatePredictionEndpoint:
|
||||
"""Test POST /api/v1/predictions/matches/{match_id}/predict endpoint."""
|
||||
|
||||
def test_create_prediction_success(self, client: TestClient, sample_match: Match):
|
||||
"""Test creating a prediction successfully."""
|
||||
response = client.post(
|
||||
f"/api/v1/predictions/matches/{sample_match.id}/predict",
|
||||
params={
|
||||
"home_energy": 65.0,
|
||||
"away_energy": 45.0,
|
||||
"energy_score_label": "high"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
|
||||
assert data["match_id"] == sample_match.id
|
||||
assert data["energy_score"] == "high"
|
||||
assert "%" in data["confidence"]
|
||||
assert data["predicted_winner"] == "PSG"
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
|
||||
def test_create_prediction_default_energy_label(self, client: TestClient, sample_match: Match):
|
||||
"""Test creating prediction with auto-generated energy label."""
|
||||
response = client.post(
|
||||
f"/api/v1/predictions/matches/{sample_match.id}/predict",
|
||||
params={
|
||||
"home_energy": 75.0,
|
||||
"away_energy": 65.0
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
|
||||
assert data["energy_score"] in ["high", "very_high", "medium", "low"]
|
||||
|
||||
def test_create_prediction_home_win(self, client: TestClient, sample_match: Match):
|
||||
"""Test prediction when home team wins."""
|
||||
response = client.post(
|
||||
f"/api/v1/predictions/matches/{sample_match.id}/predict",
|
||||
params={
|
||||
"home_energy": 70.0,
|
||||
"away_energy": 30.0
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["predicted_winner"] == "PSG"
|
||||
|
||||
def test_create_prediction_away_win(self, client: TestClient, sample_match: Match):
|
||||
"""Test prediction when away team wins."""
|
||||
response = client.post(
|
||||
f"/api/v1/predictions/matches/{sample_match.id}/predict",
|
||||
params={
|
||||
"home_energy": 30.0,
|
||||
"away_energy": 70.0
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["predicted_winner"] == "Olympique de Marseille"
|
||||
|
||||
def test_create_prediction_draw(self, client: TestClient, sample_match: Match):
|
||||
"""Test prediction for draw."""
|
||||
response = client.post(
|
||||
f"/api/v1/predictions/matches/{sample_match.id}/predict",
|
||||
params={
|
||||
"home_energy": 50.0,
|
||||
"away_energy": 50.0
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["predicted_winner"] == "Draw"
|
||||
assert response.json()["confidence"] == "0.0%"
|
||||
|
||||
def test_create_prediction_nonexistent_match(self, client: TestClient):
|
||||
"""Test creating prediction for non-existent match."""
|
||||
response = client.post(
|
||||
"/api/v1/predictions/matches/999/predict",
|
||||
params={
|
||||
"home_energy": 65.0,
|
||||
"away_energy": 45.0
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
def test_create_prediction_negative_energy(self, client: TestClient, sample_match: Match):
|
||||
"""Test creating prediction with negative energy."""
|
||||
response = client.post(
|
||||
f"/api/v1/predictions/matches/{sample_match.id}/predict",
|
||||
params={
|
||||
"home_energy": -10.0,
|
||||
"away_energy": 45.0
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
class TestGetPredictionEndpoint:
|
||||
"""Test GET /api/v1/predictions/{prediction_id} endpoint."""
|
||||
|
||||
def test_get_prediction_success(self, client: TestClient, sample_prediction: Prediction):
|
||||
"""Test getting a prediction by ID successfully."""
|
||||
response = client.get(f"/api/v1/predictions/{sample_prediction.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == sample_prediction.id
|
||||
assert data["match_id"] == sample_prediction.match_id
|
||||
assert data["energy_score"] == sample_prediction.energy_score
|
||||
assert data["confidence"] == sample_prediction.confidence
|
||||
assert data["predicted_winner"] == sample_prediction.predicted_winner
|
||||
|
||||
def test_get_prediction_not_found(self, client: TestClient):
|
||||
"""Test getting non-existent prediction."""
|
||||
response = client.get("/api/v1/predictions/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestGetPredictionsForMatchEndpoint:
|
||||
"""Test GET /api/v1/predictions/matches/{match_id} endpoint."""
|
||||
|
||||
def test_get_predictions_for_match_success(self, client: TestClient, db_session: Session, sample_match: Match):
|
||||
"""Test getting all predictions for a match."""
|
||||
# Create multiple predictions
|
||||
from app.services.prediction_service import PredictionService
|
||||
service = PredictionService(db_session)
|
||||
|
||||
for i in range(3):
|
||||
service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=50.0 + i * 10,
|
||||
away_energy=40.0 + i * 5
|
||||
)
|
||||
|
||||
response = client.get(f"/api/v1/predictions/matches/{sample_match.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "data" in data
|
||||
assert "count" in data
|
||||
assert data["count"] == 3
|
||||
assert len(data["data"]) == 3
|
||||
assert all(p["match_id"] == sample_match.id for p in data["data"])
|
||||
|
||||
def test_get_predictions_for_empty_match(self, client: TestClient, sample_match: Match):
|
||||
"""Test getting predictions for match with no predictions."""
|
||||
response = client.get(f"/api/v1/predictions/matches/{sample_match.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["count"] == 0
|
||||
assert len(data["data"]) == 0
|
||||
|
||||
|
||||
class TestGetLatestPredictionForMatchEndpoint:
|
||||
"""Test GET /api/v1/predictions/matches/{match_id}/latest endpoint."""
|
||||
|
||||
def test_get_latest_prediction_success(self, client: TestClient, db_session: Session, sample_match: Match):
|
||||
"""Test getting latest prediction for a match."""
|
||||
from app.services.prediction_service import PredictionService
|
||||
service = PredictionService(db_session)
|
||||
|
||||
# Create multiple predictions
|
||||
prediction1 = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=50.0,
|
||||
away_energy=40.0
|
||||
)
|
||||
|
||||
import time
|
||||
time.sleep(0.01)
|
||||
|
||||
prediction2 = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=60.0,
|
||||
away_energy=35.0
|
||||
)
|
||||
|
||||
response = client.get(f"/api/v1/predictions/matches/{sample_match.id}/latest")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == prediction2.id
|
||||
assert data["match_id"] == sample_match.id
|
||||
|
||||
def test_get_latest_prediction_not_found(self, client: TestClient, sample_match: Match):
|
||||
"""Test getting latest prediction for match with no predictions."""
|
||||
response = client.get(f"/api/v1/predictions/matches/{sample_match.id}/latest")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "No predictions found" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestDeletePredictionEndpoint:
|
||||
"""Test DELETE /api/v1/predictions/{prediction_id} endpoint."""
|
||||
|
||||
def test_delete_prediction_success(self, client: TestClient, sample_prediction: Prediction):
|
||||
"""Test deleting a prediction successfully."""
|
||||
response = client.delete(f"/api/v1/predictions/{sample_prediction.id}")
|
||||
|
||||
assert response.status_code == 204
|
||||
|
||||
# Verify prediction is deleted
|
||||
get_response = client.get(f"/api/v1/predictions/{sample_prediction.id}")
|
||||
assert get_response.status_code == 404
|
||||
|
||||
def test_delete_prediction_not_found(self, client: TestClient):
|
||||
"""Test deleting non-existent prediction."""
|
||||
response = client.delete("/api/v1/predictions/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestGetPredictionsEndpoint:
|
||||
"""Test GET /api/v1/predictions endpoint with pagination and filters."""
|
||||
|
||||
def test_get_predictions_default(self, client: TestClient, db_session: Session, sample_match: Match):
|
||||
"""Test getting predictions with default parameters."""
|
||||
from app.services.prediction_service import PredictionService
|
||||
service = PredictionService(db_session)
|
||||
|
||||
# Create multiple predictions
|
||||
for i in range(5):
|
||||
service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=50.0 + i * 5,
|
||||
away_energy=40.0 + i * 3
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/predictions")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "data" in data
|
||||
assert "meta" in data
|
||||
assert "total" in data["meta"]
|
||||
assert "limit" in data["meta"]
|
||||
assert "offset" in data["meta"]
|
||||
assert "timestamp" in data["meta"]
|
||||
assert data["meta"]["version"] == "v1"
|
||||
assert data["meta"]["limit"] == 20
|
||||
assert data["meta"]["offset"] == 0
|
||||
|
||||
def test_get_predictions_with_pagination(self, client: TestClient, db_session: Session, sample_match: Match):
|
||||
"""Test getting predictions with custom pagination."""
|
||||
from app.services.prediction_service import PredictionService
|
||||
service = PredictionService(db_session)
|
||||
|
||||
# Create 25 predictions
|
||||
for i in range(25):
|
||||
service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=50.0 + i,
|
||||
away_energy=40.0 + i
|
||||
)
|
||||
|
||||
# Get first page with limit 10
|
||||
response = client.get("/api/v1/predictions?limit=10&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert len(data["data"]) == 10
|
||||
assert data["meta"]["limit"] == 10
|
||||
assert data["meta"]["offset"] == 0
|
||||
assert data["meta"]["total"] == 25
|
||||
|
||||
# Get second page
|
||||
response2 = client.get("/api/v1/predictions?limit=10&offset=10")
|
||||
|
||||
assert response2.status_code == 200
|
||||
data2 = response2.json()
|
||||
|
||||
assert len(data2["data"]) == 10
|
||||
assert data2["meta"]["offset"] == 10
|
||||
|
||||
def test_get_predictions_with_league_filter(self, client: TestClient, db_session: Session):
|
||||
"""Test filtering predictions by league."""
|
||||
from app.services.prediction_service import PredictionService
|
||||
service = PredictionService(db_session)
|
||||
|
||||
# Create matches in different leagues
|
||||
match1 = Match(
|
||||
home_team="PSG",
|
||||
away_team="Olympique de Marseille",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Ligue 1",
|
||||
status="scheduled"
|
||||
)
|
||||
match2 = Match(
|
||||
home_team="Real Madrid",
|
||||
away_team="Barcelona",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="La Liga",
|
||||
status="scheduled"
|
||||
)
|
||||
db_session.add_all([match1, match2])
|
||||
db_session.commit()
|
||||
db_session.refresh(match1)
|
||||
db_session.refresh(match2)
|
||||
|
||||
# Create predictions for both matches
|
||||
service.create_prediction_for_match(match1.id, 65.0, 45.0)
|
||||
service.create_prediction_for_match(match2.id, 60.0, 50.0)
|
||||
|
||||
# Filter by Ligue 1
|
||||
response = client.get("/api/v1/predictions?league=Ligue 1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["match"]["league"] == "Ligue 1"
|
||||
|
||||
def test_get_predictions_with_date_filter(self, client: TestClient, db_session: Session, sample_match: Match):
|
||||
"""Test filtering predictions by date range."""
|
||||
from app.services.prediction_service import PredictionService
|
||||
service = PredictionService(db_session)
|
||||
|
||||
# Create predictions for matches with different dates
|
||||
future_date = datetime.now(timezone.utc).replace(day=20, month=2)
|
||||
past_date = datetime.now(timezone.utc).replace(day=1, month=1, year=2025)
|
||||
|
||||
future_match = Match(
|
||||
home_team="PSG",
|
||||
away_team="Olympique de Marseille",
|
||||
date=future_date,
|
||||
league="Ligue 1",
|
||||
status="scheduled"
|
||||
)
|
||||
past_match = Match(
|
||||
home_team="Real Madrid",
|
||||
away_team="Barcelona",
|
||||
date=past_date,
|
||||
league="La Liga",
|
||||
status="scheduled"
|
||||
)
|
||||
db_session.add_all([future_match, past_match])
|
||||
db_session.commit()
|
||||
db_session.refresh(future_match)
|
||||
db_session.refresh(past_match)
|
||||
|
||||
service.create_prediction_for_match(future_match.id, 65.0, 45.0)
|
||||
service.create_prediction_for_match(past_match.id, 60.0, 50.0)
|
||||
|
||||
# Filter by date range
|
||||
response = client.get(
|
||||
f"/api/v1/predictions?date_min=2025-01-01T00:00:00Z"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Should include future match but not past match
|
||||
assert len(data["data"]) >= 1
|
||||
assert any(pred["match"]["date"] >= "2025-01-01" for pred in data["data"])
|
||||
|
||||
def test_get_predictions_invalid_limit(self, client: TestClient):
|
||||
"""Test getting predictions with invalid limit."""
|
||||
response = client.get("/api/v1/predictions?limit=150")
|
||||
|
||||
# Pydantic should validate this before endpoint executes
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_get_predictions_empty(self, client: TestClient):
|
||||
"""Test getting predictions when none exist."""
|
||||
response = client.get("/api/v1/predictions")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert len(data["data"]) == 0
|
||||
assert data["meta"]["total"] == 0
|
||||
|
||||
|
||||
class TestGetPredictionByMatchIdEndpoint:
|
||||
"""Test GET /api/v1/predictions/match/{match_id} endpoint."""
|
||||
|
||||
def test_get_prediction_by_match_id_success(self, client: TestClient, db_session: Session, sample_match: Match):
|
||||
"""Test getting prediction details by match ID."""
|
||||
from app.services.prediction_service import PredictionService
|
||||
service = PredictionService(db_session)
|
||||
|
||||
# Create a prediction
|
||||
prediction = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=65.0,
|
||||
away_energy=45.0
|
||||
)
|
||||
|
||||
response = client.get(f"/api/v1/predictions/match/{sample_match.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == prediction.id
|
||||
assert data["match_id"] == sample_match.id
|
||||
assert "match" in data
|
||||
assert data["match"]["home_team"] == "PSG"
|
||||
assert data["match"]["away_team"] == "Olympique de Marseille"
|
||||
assert "energy_score" in data
|
||||
assert "confidence" in data
|
||||
assert "predicted_winner" in data
|
||||
assert "history" in data
|
||||
assert len(data["history"]) >= 1
|
||||
|
||||
def test_get_prediction_by_match_id_with_history(self, client: TestClient, db_session: Session, sample_match: Match):
|
||||
"""Test getting prediction with history."""
|
||||
from app.services.prediction_service import PredictionService
|
||||
service = PredictionService(db_session)
|
||||
|
||||
# Create multiple predictions for the same match
|
||||
import time
|
||||
prediction1 = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=50.0,
|
||||
away_energy=40.0
|
||||
)
|
||||
time.sleep(0.01)
|
||||
prediction2 = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=60.0,
|
||||
away_energy=35.0
|
||||
)
|
||||
time.sleep(0.01)
|
||||
prediction3 = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=70.0,
|
||||
away_energy=30.0
|
||||
)
|
||||
|
||||
response = client.get(f"/api/v1/predictions/match/{sample_match.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Should return latest prediction
|
||||
assert data["id"] == prediction3.id
|
||||
|
||||
# History should contain all predictions
|
||||
assert len(data["history"]) == 3
|
||||
assert data["history"][0]["id"] == prediction3.id
|
||||
assert data["history"][1]["id"] == prediction2.id
|
||||
assert data["history"][2]["id"] == prediction1.id
|
||||
|
||||
def test_get_prediction_by_match_id_not_found(self, client: TestClient):
|
||||
"""Test getting prediction for non-existent match."""
|
||||
response = client.get("/api/v1/predictions/match/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "No predictions found" in response.json()["detail"]
|
||||
|
||||
def test_get_prediction_by_match_id_no_predictions(self, client: TestClient, sample_match: Match):
|
||||
"""Test getting prediction for match without predictions."""
|
||||
response = client.get(f"/api/v1/predictions/match/{sample_match.id}")
|
||||
|
||||
assert response.status_code == 404
|
||||
212
backend/tests/test_prediction_calculator.py
Normal file
212
backend/tests/test_prediction_calculator.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Unit tests for Prediction Calculator module.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from app.ml.prediction_calculator import (
|
||||
calculate_prediction,
|
||||
calculate_confidence_meter,
|
||||
determine_winner,
|
||||
validate_prediction_result
|
||||
)
|
||||
|
||||
|
||||
class TestConfidenceMeterCalculation:
|
||||
"""Test Confidence Meter calculation logic."""
|
||||
|
||||
def test_confidence_meter_equal_energy(self):
|
||||
"""Test confidence meter when both teams have equal energy."""
|
||||
result = calculate_confidence_meter(home_energy=50.0, away_energy=50.0)
|
||||
assert result == 0.0 # No confidence when teams are equal
|
||||
|
||||
def test_confidence_meter_small_difference(self):
|
||||
"""Test confidence meter with small energy difference."""
|
||||
result = calculate_confidence_meter(home_energy=52.0, away_energy=50.0)
|
||||
assert 0 < result < 100
|
||||
assert result == 4.0 # 2 * 2.0 = 4.0%
|
||||
|
||||
def test_confidence_meter_medium_difference(self):
|
||||
"""Test confidence meter with medium energy difference."""
|
||||
result = calculate_confidence_meter(home_energy=60.0, away_energy=50.0)
|
||||
assert 0 < result < 100
|
||||
assert result == 20.0 # 2 * 10.0 = 20.0%
|
||||
|
||||
def test_confidence_meter_large_difference(self):
|
||||
"""Test confidence meter with large energy difference."""
|
||||
result = calculate_confidence_meter(home_energy=80.0, away_energy=50.0)
|
||||
assert 0 < result <= 100
|
||||
assert result == 60.0 # 2 * 30.0 = 60.0%
|
||||
|
||||
def test_confidence_meter_very_large_difference(self):
|
||||
"""Test confidence meter caps at 100% for very large differences."""
|
||||
result = calculate_confidence_meter(home_energy=100.0, away_energy=50.0)
|
||||
assert result == 100.0 # Capped at 100%
|
||||
|
||||
def test_confidence_meter_zero_energy(self):
|
||||
"""Test confidence meter when one team has zero energy."""
|
||||
result = calculate_confidence_meter(home_energy=0.0, away_energy=50.0)
|
||||
assert result == 100.0 # Max confidence
|
||||
|
||||
def test_confidence_meter_negative_energy(self):
|
||||
"""Test confidence meter with negative energy values."""
|
||||
result = calculate_confidence_meter(home_energy=-10.0, away_energy=50.0)
|
||||
assert result == 100.0 # Capped at 100%
|
||||
|
||||
|
||||
class TestWinnerDetermination:
|
||||
"""Test winner determination logic."""
|
||||
|
||||
def test_home_team_wins(self):
|
||||
"""Test home team wins when home energy is higher."""
|
||||
result = determine_winner(home_energy=60.0, away_energy=40.0)
|
||||
assert result == "home"
|
||||
|
||||
def test_away_team_wins(self):
|
||||
"""Test away team wins when away energy is higher."""
|
||||
result = determine_winner(home_energy=40.0, away_energy=60.0)
|
||||
assert result == "away"
|
||||
|
||||
def test_draw_equal_energy(self):
|
||||
"""Test draw when both teams have equal energy."""
|
||||
result = determine_winner(home_energy=50.0, away_energy=50.0)
|
||||
assert result == "draw"
|
||||
|
||||
def test_draw_almost_equal(self):
|
||||
"""Test draw when energy difference is very small."""
|
||||
result = determine_winner(home_energy=50.01, away_energy=50.0)
|
||||
assert result == "home" # Should prefer home on tie
|
||||
|
||||
|
||||
class TestPredictionCalculation:
|
||||
"""Test complete prediction calculation."""
|
||||
|
||||
def test_prediction_home_win(self):
|
||||
"""Test prediction calculation for home team win."""
|
||||
result = calculate_prediction(home_energy=65.0, away_energy=45.0)
|
||||
|
||||
assert 'confidence' in result
|
||||
assert 'predicted_winner' in result
|
||||
assert 'home_energy' in result
|
||||
assert 'away_energy' in result
|
||||
|
||||
assert result['confidence'] > 0
|
||||
assert result['predicted_winner'] == 'home'
|
||||
assert result['home_energy'] == 65.0
|
||||
assert result['away_energy'] == 45.0
|
||||
|
||||
def test_prediction_away_win(self):
|
||||
"""Test prediction calculation for away team win."""
|
||||
result = calculate_prediction(home_energy=35.0, away_energy=70.0)
|
||||
|
||||
assert result['confidence'] > 0
|
||||
assert result['predicted_winner'] == 'away'
|
||||
assert result['home_energy'] == 35.0
|
||||
assert result['away_energy'] == 70.0
|
||||
|
||||
def test_prediction_draw(self):
|
||||
"""Test prediction calculation for draw."""
|
||||
result = calculate_prediction(home_energy=50.0, away_energy=50.0)
|
||||
|
||||
assert result['confidence'] == 0.0
|
||||
assert result['predicted_winner'] == 'draw'
|
||||
|
||||
def test_prediction_high_confidence(self):
|
||||
"""Test prediction with high confidence."""
|
||||
result = calculate_prediction(home_energy=90.0, away_energy=30.0)
|
||||
|
||||
assert result['confidence'] >= 80.0 # High confidence
|
||||
assert result['predicted_winner'] == 'home'
|
||||
|
||||
|
||||
class TestPredictionValidation:
|
||||
"""Test prediction result validation."""
|
||||
|
||||
def test_valid_prediction_result(self):
|
||||
"""Test validation of valid prediction result."""
|
||||
result = {
|
||||
'confidence': 75.0,
|
||||
'predicted_winner': 'home',
|
||||
'home_energy': 65.0,
|
||||
'away_energy': 45.0
|
||||
}
|
||||
|
||||
assert validate_prediction_result(result) is True
|
||||
|
||||
def test_invalid_confidence_negative(self):
|
||||
"""Test validation fails with negative confidence."""
|
||||
result = {
|
||||
'confidence': -10.0,
|
||||
'predicted_winner': 'home',
|
||||
'home_energy': 65.0,
|
||||
'away_energy': 45.0
|
||||
}
|
||||
|
||||
assert validate_prediction_result(result) is False
|
||||
|
||||
def test_invalid_confidence_over_100(self):
|
||||
"""Test validation fails with confidence over 100%."""
|
||||
result = {
|
||||
'confidence': 150.0,
|
||||
'predicted_winner': 'home',
|
||||
'home_energy': 65.0,
|
||||
'away_energy': 45.0
|
||||
}
|
||||
|
||||
assert validate_prediction_result(result) is False
|
||||
|
||||
def test_invalid_winner_value(self):
|
||||
"""Test validation fails with invalid winner value."""
|
||||
result = {
|
||||
'confidence': 75.0,
|
||||
'predicted_winner': 'invalid',
|
||||
'home_energy': 65.0,
|
||||
'away_energy': 45.0
|
||||
}
|
||||
|
||||
assert validate_prediction_result(result) is False
|
||||
|
||||
def test_invalid_missing_fields(self):
|
||||
"""Test validation fails when required fields are missing."""
|
||||
result = {
|
||||
'confidence': 75.0,
|
||||
'predicted_winner': 'home'
|
||||
# Missing home_energy and away_energy
|
||||
}
|
||||
|
||||
assert validate_prediction_result(result) is False
|
||||
|
||||
def test_invalid_energy_negative(self):
|
||||
"""Test validation fails with negative energy values."""
|
||||
result = {
|
||||
'confidence': 75.0,
|
||||
'predicted_winner': 'home',
|
||||
'home_energy': -10.0,
|
||||
'away_energy': 45.0
|
||||
}
|
||||
|
||||
assert validate_prediction_result(result) is False
|
||||
|
||||
|
||||
class TestPredictionEdgeCases:
|
||||
"""Test prediction calculation edge cases."""
|
||||
|
||||
def test_both_teams_zero_energy(self):
|
||||
"""Test prediction when both teams have zero energy."""
|
||||
result = calculate_prediction(home_energy=0.0, away_energy=0.0)
|
||||
|
||||
assert result['confidence'] == 0.0
|
||||
assert result['predicted_winner'] == 'draw'
|
||||
|
||||
def test_very_high_energy_values(self):
|
||||
"""Test prediction with very high energy values."""
|
||||
result = calculate_prediction(home_energy=1000.0, away_energy=500.0)
|
||||
|
||||
assert result['confidence'] == 100.0 # Should be capped
|
||||
assert result['predicted_winner'] == 'home'
|
||||
|
||||
def test_decimal_energy_values(self):
|
||||
"""Test prediction with decimal energy values."""
|
||||
result = calculate_prediction(home_energy=55.5, away_energy=54.5)
|
||||
|
||||
assert 0 < result['confidence'] < 100
|
||||
assert result['predicted_winner'] == 'home'
|
||||
255
backend/tests/test_prediction_model.py
Normal file
255
backend/tests/test_prediction_model.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""
|
||||
Unit tests for Prediction model.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.match import Match
|
||||
from app.models.prediction import Prediction
|
||||
from app.database import Base, engine, SessionLocal
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
"""Create a fresh database session for each test."""
|
||||
# Create tables
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# Create session
|
||||
session = SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
session.rollback()
|
||||
finally:
|
||||
session.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
class TestPredictionModel:
|
||||
"""Test Prediction SQLAlchemy model."""
|
||||
|
||||
def test_prediction_creation(self, db_session: Session):
|
||||
"""Test creating a prediction in database."""
|
||||
# First create a match
|
||||
match = Match(
|
||||
home_team="PSG",
|
||||
away_team="Olympique de Marseille",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Ligue 1",
|
||||
status="scheduled"
|
||||
)
|
||||
|
||||
db_session.add(match)
|
||||
db_session.commit()
|
||||
db_session.refresh(match)
|
||||
|
||||
# Create prediction
|
||||
prediction = Prediction(
|
||||
match_id=match.id,
|
||||
energy_score="high",
|
||||
confidence="85%",
|
||||
predicted_winner="PSG",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db_session.add(prediction)
|
||||
db_session.commit()
|
||||
db_session.refresh(prediction)
|
||||
|
||||
assert prediction.id is not None
|
||||
assert prediction.match_id == match.id
|
||||
assert prediction.energy_score == "high"
|
||||
assert prediction.confidence == "85%"
|
||||
assert prediction.predicted_winner == "PSG"
|
||||
|
||||
def test_prediction_required_fields(self, db_session: Session):
|
||||
"""Test that all required fields must be provided."""
|
||||
# Create a match first
|
||||
match = Match(
|
||||
home_team="Barcelona",
|
||||
away_team="Real Madrid",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="La Liga",
|
||||
status="scheduled"
|
||||
)
|
||||
|
||||
db_session.add(match)
|
||||
db_session.commit()
|
||||
db_session.refresh(match)
|
||||
|
||||
# Missing predicted_winner
|
||||
prediction = Prediction(
|
||||
match_id=match.id,
|
||||
energy_score="high",
|
||||
confidence="90%",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db_session.add(prediction)
|
||||
|
||||
with pytest.raises(Exception): # IntegrityError expected
|
||||
db_session.commit()
|
||||
|
||||
def test_prediction_foreign_key_constraint(self, db_session: Session):
|
||||
"""Test that match_id must reference an existing match."""
|
||||
prediction = Prediction(
|
||||
match_id=999, # Non-existent match
|
||||
energy_score="high",
|
||||
confidence="90%",
|
||||
predicted_winner="PSG",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db_session.add(prediction)
|
||||
|
||||
with pytest.raises(Exception): # IntegrityError expected
|
||||
db_session.commit()
|
||||
|
||||
def test_prediction_to_dict(self, db_session: Session):
|
||||
"""Test converting prediction to dictionary."""
|
||||
match = Match(
|
||||
home_team="Manchester City",
|
||||
away_team="Liverpool",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Premier League",
|
||||
status="scheduled"
|
||||
)
|
||||
|
||||
db_session.add(match)
|
||||
db_session.commit()
|
||||
db_session.refresh(match)
|
||||
|
||||
prediction = Prediction(
|
||||
match_id=match.id,
|
||||
energy_score="medium",
|
||||
confidence="70%",
|
||||
predicted_winner="Liverpool",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db_session.add(prediction)
|
||||
db_session.commit()
|
||||
db_session.refresh(prediction)
|
||||
|
||||
prediction_dict = prediction.to_dict()
|
||||
|
||||
assert prediction_dict['match_id'] == match.id
|
||||
assert prediction_dict['energy_score'] == "medium"
|
||||
assert prediction_dict['confidence'] == "70%"
|
||||
assert prediction_dict['predicted_winner'] == "Liverpool"
|
||||
assert 'id' in prediction_dict
|
||||
assert 'created_at' in prediction_dict
|
||||
|
||||
def test_prediction_repr(self, db_session: Session):
|
||||
"""Test prediction __repr__ method."""
|
||||
match = Match(
|
||||
home_team="Juventus",
|
||||
away_team="Inter Milan",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Serie A",
|
||||
status="scheduled"
|
||||
)
|
||||
|
||||
db_session.add(match)
|
||||
db_session.commit()
|
||||
db_session.refresh(match)
|
||||
|
||||
prediction = Prediction(
|
||||
match_id=match.id,
|
||||
energy_score="low",
|
||||
confidence="50%",
|
||||
predicted_winner="Juventus",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db_session.add(prediction)
|
||||
db_session.commit()
|
||||
db_session.refresh(prediction)
|
||||
|
||||
repr_str = repr(prediction)
|
||||
|
||||
assert "Prediction" in repr_str
|
||||
assert "id=" in repr_str
|
||||
assert "match_id=" in repr_str
|
||||
assert "confidence=50%" in repr_str
|
||||
|
||||
def test_prediction_match_relationship(self, db_session: Session):
|
||||
"""Test prediction relationship with match."""
|
||||
match = Match(
|
||||
home_team="Bayern Munich",
|
||||
away_team="Borussia Dortmund",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Bundesliga",
|
||||
status="scheduled"
|
||||
)
|
||||
|
||||
db_session.add(match)
|
||||
db_session.commit()
|
||||
db_session.refresh(match)
|
||||
|
||||
prediction = Prediction(
|
||||
match_id=match.id,
|
||||
energy_score="very_high",
|
||||
confidence="95%",
|
||||
predicted_winner="Bayern Munich",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db_session.add(prediction)
|
||||
db_session.commit()
|
||||
db_session.refresh(prediction)
|
||||
|
||||
# Access relationship
|
||||
assert prediction.match.id == match.id
|
||||
assert prediction.match.home_team == "Bayern Munich"
|
||||
assert prediction.match.away_team == "Borussia Dortmund"
|
||||
|
||||
def test_multiple_predictions_per_match(self, db_session: Session):
|
||||
"""Test that a match can have multiple predictions."""
|
||||
match = Match(
|
||||
home_team="Ajax",
|
||||
away_team="Feyenoord",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Eredivisie",
|
||||
status="scheduled"
|
||||
)
|
||||
|
||||
db_session.add(match)
|
||||
db_session.commit()
|
||||
db_session.refresh(match)
|
||||
|
||||
# Create multiple predictions
|
||||
prediction1 = Prediction(
|
||||
match_id=match.id,
|
||||
energy_score="high",
|
||||
confidence="80%",
|
||||
predicted_winner="Ajax",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
prediction2 = Prediction(
|
||||
match_id=match.id,
|
||||
energy_score="medium",
|
||||
confidence="60%",
|
||||
predicted_winner="Ajax",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
prediction3 = Prediction(
|
||||
match_id=match.id,
|
||||
energy_score="low",
|
||||
confidence="40%",
|
||||
predicted_winner="Feyenoord",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db_session.add(prediction1)
|
||||
db_session.add(prediction2)
|
||||
db_session.add(prediction3)
|
||||
db_session.commit()
|
||||
|
||||
# Refresh match and check predictions
|
||||
db_session.refresh(match)
|
||||
assert len(match.predictions) == 3
|
||||
211
backend/tests/test_prediction_schema.py
Normal file
211
backend/tests/test_prediction_schema.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Unit tests for Prediction Pydantic schemas.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.prediction import (
|
||||
PredictionBase,
|
||||
PredictionCreate,
|
||||
PredictionUpdate,
|
||||
PredictionResponse,
|
||||
PredictionListResponse,
|
||||
PredictionStatsResponse
|
||||
)
|
||||
|
||||
|
||||
class TestPredictionBase:
|
||||
"""Test PredictionBase schema."""
|
||||
|
||||
def test_prediction_base_valid(self):
|
||||
"""Test creating a valid PredictionBase."""
|
||||
prediction_data = {
|
||||
"match_id": 1,
|
||||
"energy_score": "high",
|
||||
"confidence": "85%",
|
||||
"predicted_winner": "PSG",
|
||||
"created_at": datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
prediction = PredictionBase(**prediction_data)
|
||||
|
||||
assert prediction.match_id == 1
|
||||
assert prediction.energy_score == "high"
|
||||
assert prediction.confidence == "85%"
|
||||
assert prediction.predicted_winner == "PSG"
|
||||
|
||||
def test_prediction_base_energy_score_too_long(self):
|
||||
"""Test that energy_score exceeds max length."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
PredictionBase(
|
||||
match_id=1,
|
||||
energy_score="A" * 51, # Too long
|
||||
confidence="85%",
|
||||
predicted_winner="PSG",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
assert "at most 50 characters" in str(exc_info.value).lower()
|
||||
|
||||
def test_prediction_base_confidence_too_long(self):
|
||||
"""Test that confidence exceeds max length."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
PredictionBase(
|
||||
match_id=1,
|
||||
energy_score="high",
|
||||
confidence="A" * 51, # Too long
|
||||
predicted_winner="PSG",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
assert "at most 50 characters" in str(exc_info.value).lower()
|
||||
|
||||
def test_prediction_base_predicted_winner_too_long(self):
|
||||
"""Test that predicted_winner exceeds max length."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
PredictionBase(
|
||||
match_id=1,
|
||||
energy_score="high",
|
||||
confidence="85%",
|
||||
predicted_winner="A" * 256, # Too long
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
assert "at most 255 characters" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestPredictionCreate:
|
||||
"""Test PredictionCreate schema."""
|
||||
|
||||
def test_prediction_create_valid(self):
|
||||
"""Test creating a valid PredictionCreate."""
|
||||
prediction_data = {
|
||||
"match_id": 1,
|
||||
"energy_score": "medium",
|
||||
"confidence": "70%",
|
||||
"predicted_winner": "Barcelona",
|
||||
"created_at": datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
prediction = PredictionCreate(**prediction_data)
|
||||
|
||||
assert prediction.match_id == 1
|
||||
assert prediction.energy_score == "medium"
|
||||
assert prediction.confidence == "70%"
|
||||
assert prediction.predicted_winner == "Barcelona"
|
||||
|
||||
|
||||
class TestPredictionUpdate:
|
||||
"""Test PredictionUpdate schema."""
|
||||
|
||||
def test_prediction_update_partial(self):
|
||||
"""Test updating only some fields."""
|
||||
update_data = {
|
||||
"confidence": "90%"
|
||||
}
|
||||
|
||||
prediction_update = PredictionUpdate(**update_data)
|
||||
|
||||
assert prediction_update.confidence == "90%"
|
||||
assert prediction_update.energy_score is None
|
||||
assert prediction_update.predicted_winner is None
|
||||
|
||||
def test_prediction_update_all_fields(self):
|
||||
"""Test updating all fields."""
|
||||
update_data = {
|
||||
"energy_score": "very_high",
|
||||
"confidence": "95%",
|
||||
"predicted_winner": "Real Madrid"
|
||||
}
|
||||
|
||||
prediction_update = PredictionUpdate(**update_data)
|
||||
|
||||
assert prediction_update.energy_score == "very_high"
|
||||
assert prediction_update.confidence == "95%"
|
||||
assert prediction_update.predicted_winner == "Real Madrid"
|
||||
|
||||
def test_prediction_update_empty(self):
|
||||
"""Test that PredictionUpdate can be empty."""
|
||||
prediction_update = PredictionUpdate()
|
||||
|
||||
assert prediction_update.energy_score is None
|
||||
assert prediction_update.confidence is None
|
||||
assert prediction_update.predicted_winner is None
|
||||
|
||||
|
||||
class TestPredictionResponse:
|
||||
"""Test PredictionResponse schema."""
|
||||
|
||||
def test_prediction_response_from_dict(self):
|
||||
"""Test creating PredictionResponse from dictionary."""
|
||||
prediction_dict = {
|
||||
"id": 1,
|
||||
"match_id": 1,
|
||||
"energy_score": "high",
|
||||
"confidence": "85%",
|
||||
"predicted_winner": "Manchester City",
|
||||
"created_at": datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
prediction = PredictionResponse(**prediction_dict)
|
||||
|
||||
assert prediction.id == 1
|
||||
assert prediction.match_id == 1
|
||||
assert prediction.energy_score == "high"
|
||||
assert prediction.confidence == "85%"
|
||||
assert prediction.predicted_winner == "Manchester City"
|
||||
|
||||
|
||||
class TestPredictionListResponse:
|
||||
"""Test PredictionListResponse schema."""
|
||||
|
||||
def test_prediction_list_response(self):
|
||||
"""Test creating a PredictionListResponse."""
|
||||
predictions_data = [
|
||||
{
|
||||
"id": 1,
|
||||
"match_id": 1,
|
||||
"energy_score": "high",
|
||||
"confidence": "85%",
|
||||
"predicted_winner": "PSG",
|
||||
"created_at": datetime.now(timezone.utc)
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"match_id": 2,
|
||||
"energy_score": "medium",
|
||||
"confidence": "70%",
|
||||
"predicted_winner": "Barcelona",
|
||||
"created_at": datetime.now(timezone.utc)
|
||||
}
|
||||
]
|
||||
|
||||
response = PredictionListResponse(data=predictions_data, count=2, meta={"page": 1})
|
||||
|
||||
assert response.count == 2
|
||||
assert len(response.data) == 2
|
||||
assert response.meta["page"] == 1
|
||||
|
||||
|
||||
class TestPredictionStatsResponse:
|
||||
"""Test PredictionStatsResponse schema."""
|
||||
|
||||
def test_prediction_stats_response(self):
|
||||
"""Test creating a PredictionStatsResponse."""
|
||||
stats = {
|
||||
"total_predictions": 20,
|
||||
"predictions_by_confidence": {"high": 10, "medium": 5, "low": 5},
|
||||
"predictions_by_energy_score": {"high": 8, "medium": 7, "low": 5},
|
||||
"avg_confidence": 75.5,
|
||||
"unique_matches_predicted": 15
|
||||
}
|
||||
|
||||
prediction_stats = PredictionStatsResponse(**stats)
|
||||
|
||||
assert prediction_stats.total_predictions == 20
|
||||
assert prediction_stats.predictions_by_confidence["high"] == 10
|
||||
assert prediction_stats.predictions_by_energy_score["medium"] == 7
|
||||
assert prediction_stats.avg_confidence == 75.5
|
||||
assert prediction_stats.unique_matches_predicted == 15
|
||||
258
backend/tests/test_prediction_service.py
Normal file
258
backend/tests/test_prediction_service.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Unit tests for Prediction Service.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import Base, engine, SessionLocal
|
||||
from app.models.match import Match
|
||||
from app.models.prediction import Prediction
|
||||
from app.services.prediction_service import PredictionService
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
"""Create a fresh database session for each test."""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session = SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
session.rollback()
|
||||
finally:
|
||||
session.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_match(db_session: Session) -> Match:
|
||||
"""Create a sample match for testing."""
|
||||
match = Match(
|
||||
home_team="PSG",
|
||||
away_team="Olympique de Marseille",
|
||||
date=datetime.now(timezone.utc),
|
||||
league="Ligue 1",
|
||||
status="scheduled"
|
||||
)
|
||||
db_session.add(match)
|
||||
db_session.commit()
|
||||
db_session.refresh(match)
|
||||
return match
|
||||
|
||||
|
||||
class TestPredictionServiceCreate:
|
||||
"""Test prediction service creation methods."""
|
||||
|
||||
def test_create_prediction_home_win(self, db_session: Session, sample_match: Match):
|
||||
"""Test creating a prediction where home team wins."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
prediction = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=65.0,
|
||||
away_energy=45.0,
|
||||
energy_score_label="high"
|
||||
)
|
||||
|
||||
assert prediction.id is not None
|
||||
assert prediction.match_id == sample_match.id
|
||||
assert prediction.predicted_winner == "PSG"
|
||||
assert prediction.energy_score == "high"
|
||||
assert "%" in prediction.confidence
|
||||
|
||||
def test_create_prediction_away_win(self, db_session: Session, sample_match: Match):
|
||||
"""Test creating a prediction where away team wins."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
prediction = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=35.0,
|
||||
away_energy=70.0
|
||||
)
|
||||
|
||||
assert prediction.predicted_winner == "Olympique de Marseille"
|
||||
assert prediction.id is not None
|
||||
|
||||
def test_create_prediction_draw(self, db_session: Session, sample_match: Match):
|
||||
"""Test creating a prediction for draw."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
prediction = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=50.0,
|
||||
away_energy=50.0
|
||||
)
|
||||
|
||||
assert prediction.predicted_winner == "Draw"
|
||||
assert prediction.confidence == "0.0%"
|
||||
|
||||
def test_create_prediction_with_default_energy_label(self, db_session: Session, sample_match: Match):
|
||||
"""Test creating prediction without energy score label."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
# High energy prediction
|
||||
prediction1 = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=80.0,
|
||||
away_energy=70.0
|
||||
)
|
||||
assert prediction1.energy_score in ["high", "very_high"]
|
||||
|
||||
# Low energy prediction
|
||||
prediction2 = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=20.0,
|
||||
away_energy=10.0
|
||||
)
|
||||
assert prediction2.energy_score == "low"
|
||||
|
||||
def test_create_prediction_nonexistent_match(self, db_session: Session):
|
||||
"""Test creating prediction for non-existent match raises error."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
with pytest.raises(ValueError, match="Match with id 999 not found"):
|
||||
service.create_prediction_for_match(
|
||||
match_id=999,
|
||||
home_energy=65.0,
|
||||
away_energy=45.0
|
||||
)
|
||||
|
||||
def test_create_prediction_negative_energy(self, db_session: Session, sample_match: Match):
|
||||
"""Test creating prediction with negative energy raises error."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
with pytest.raises(ValueError, match="cannot be negative"):
|
||||
service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=-10.0,
|
||||
away_energy=45.0
|
||||
)
|
||||
|
||||
def test_create_prediction_invalid_energy_type(self, db_session: Session, sample_match: Match):
|
||||
"""Test creating prediction with invalid energy type raises error."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
with pytest.raises(ValueError, match="must be numeric"):
|
||||
service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy="invalid",
|
||||
away_energy=45.0
|
||||
)
|
||||
|
||||
|
||||
class TestPredictionServiceRetrieve:
|
||||
"""Test prediction service retrieval methods."""
|
||||
|
||||
def test_get_prediction_by_id(self, db_session: Session, sample_match: Match):
|
||||
"""Test retrieving a prediction by ID."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
created = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=65.0,
|
||||
away_energy=45.0
|
||||
)
|
||||
|
||||
retrieved = service.get_prediction_by_id(created.id)
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == created.id
|
||||
assert retrieved.match_id == created.match_id
|
||||
|
||||
def test_get_prediction_by_id_not_found(self, db_session: Session):
|
||||
"""Test retrieving non-existent prediction returns None."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
retrieved = service.get_prediction_by_id(999)
|
||||
|
||||
assert retrieved is None
|
||||
|
||||
def test_get_predictions_for_match(self, db_session: Session, sample_match: Match):
|
||||
"""Test retrieving all predictions for a match."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
# Create multiple predictions
|
||||
for i in range(3):
|
||||
service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=50.0 + i * 10,
|
||||
away_energy=40.0 + i * 5
|
||||
)
|
||||
|
||||
predictions = service.get_predictions_for_match(sample_match.id)
|
||||
|
||||
assert len(predictions) == 3
|
||||
assert all(p.match_id == sample_match.id for p in predictions)
|
||||
|
||||
def test_get_predictions_for_empty_match(self, db_session: Session, sample_match: Match):
|
||||
"""Test retrieving predictions for match with no predictions."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
predictions = service.get_predictions_for_match(sample_match.id)
|
||||
|
||||
assert len(predictions) == 0
|
||||
|
||||
def test_get_latest_prediction_for_match(self, db_session: Session, sample_match: Match):
|
||||
"""Test retrieving latest prediction for a match."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
# Create predictions at different times
|
||||
prediction1 = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=50.0,
|
||||
away_energy=40.0
|
||||
)
|
||||
|
||||
# Small delay to ensure different timestamps
|
||||
import time
|
||||
time.sleep(0.01)
|
||||
|
||||
prediction2 = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=60.0,
|
||||
away_energy=35.0
|
||||
)
|
||||
|
||||
latest = service.get_latest_prediction_for_match(sample_match.id)
|
||||
|
||||
assert latest.id == prediction2.id
|
||||
assert latest.id != prediction1.id
|
||||
|
||||
def test_get_latest_prediction_for_empty_match(self, db_session: Session, sample_match: Match):
|
||||
"""Test retrieving latest prediction for match with no predictions."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
latest = service.get_latest_prediction_for_match(sample_match.id)
|
||||
|
||||
assert latest is None
|
||||
|
||||
|
||||
class TestPredictionServiceDelete:
|
||||
"""Test prediction service delete methods."""
|
||||
|
||||
def test_delete_prediction(self, db_session: Session, sample_match: Match):
|
||||
"""Test deleting a prediction."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
prediction = service.create_prediction_for_match(
|
||||
match_id=sample_match.id,
|
||||
home_energy=65.0,
|
||||
away_energy=45.0
|
||||
)
|
||||
|
||||
deleted = service.delete_prediction(prediction.id)
|
||||
|
||||
assert deleted is True
|
||||
|
||||
# Verify prediction is gone
|
||||
retrieved = service.get_prediction_by_id(prediction.id)
|
||||
assert retrieved is None
|
||||
|
||||
def test_delete_prediction_not_found(self, db_session: Session):
|
||||
"""Test deleting non-existent prediction."""
|
||||
service = PredictionService(db_session)
|
||||
|
||||
deleted = service.delete_prediction(999)
|
||||
|
||||
assert deleted is False
|
||||
178
backend/tests/test_public_api.py
Normal file
178
backend/tests/test_public_api.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Tests for public API endpoints.
|
||||
|
||||
This module tests the public API endpoints for predictions and matches.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.main import app
|
||||
from app.database import get_db
|
||||
from app.models.prediction import Prediction
|
||||
from app.models.match import Match
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_match(db: Session):
|
||||
"""Create a sample match for testing."""
|
||||
match = Match(
|
||||
home_team="PSG",
|
||||
away_team="Olympique de Marseille",
|
||||
date=datetime.utcnow(),
|
||||
league="Ligue 1",
|
||||
status="scheduled"
|
||||
)
|
||||
db.add(match)
|
||||
db.commit()
|
||||
db.refresh(match)
|
||||
return match
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_prediction(db: Session, sample_match: Match):
|
||||
"""Create a sample prediction for testing."""
|
||||
prediction = Prediction(
|
||||
match_id=sample_match.id,
|
||||
energy_score="high",
|
||||
confidence="70.5%",
|
||||
predicted_winner="PSG",
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
db.add(prediction)
|
||||
db.commit()
|
||||
db.refresh(prediction)
|
||||
return prediction
|
||||
|
||||
|
||||
class TestPublicPredictionsEndpoint:
|
||||
"""Tests for GET /api/public/v1/predictions endpoint."""
|
||||
|
||||
def test_get_public_predictions_success(self, sample_prediction: Prediction):
|
||||
"""Test successful retrieval of public predictions."""
|
||||
response = client.get("/api/public/v1/predictions")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "meta" in data
|
||||
assert isinstance(data["data"], list)
|
||||
assert len(data["data"]) > 0
|
||||
assert data["meta"]["version"] == "v1"
|
||||
|
||||
def test_get_public_predictions_with_limit(self, sample_prediction: Prediction):
|
||||
"""Test retrieval of public predictions with limit parameter."""
|
||||
response = client.get("/api/public/v1/predictions?limit=5")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "meta" in data
|
||||
assert data["meta"]["limit"] == 5
|
||||
|
||||
def test_get_public_predictions_data_structure(self, sample_prediction: Prediction):
|
||||
"""Test that public predictions have correct structure without sensitive data."""
|
||||
response = client.get("/api/public/v1/predictions")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
prediction = data["data"][0]
|
||||
|
||||
# Required fields
|
||||
assert "id" in prediction
|
||||
assert "match" in prediction
|
||||
assert "energy_score" in prediction
|
||||
assert "confidence" in prediction
|
||||
assert "predicted_winner" in prediction
|
||||
|
||||
# Match details
|
||||
match = prediction["match"]
|
||||
assert "id" in match
|
||||
assert "home_team" in match
|
||||
assert "away_team" in match
|
||||
assert "date" in match
|
||||
assert "league" in match
|
||||
assert "status" in match
|
||||
|
||||
def test_get_public_predictions_empty_database(self):
|
||||
"""Test retrieval when no predictions exist."""
|
||||
# Clear database for this test
|
||||
response = client.get("/api/public/v1/predictions")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert isinstance(data["data"], list)
|
||||
|
||||
|
||||
class TestPublicMatchesEndpoint:
|
||||
"""Tests for GET /api/public/v1/matches endpoint."""
|
||||
|
||||
def test_get_public_matches_success(self, sample_match: Match):
|
||||
"""Test successful retrieval of public matches."""
|
||||
response = client.get("/api/public/v1/matches")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "meta" in data
|
||||
assert isinstance(data["data"], list)
|
||||
assert len(data["data"]) > 0
|
||||
assert data["meta"]["version"] == "v1"
|
||||
|
||||
def test_get_public_matches_with_filters(self, sample_match: Match):
|
||||
"""Test retrieval of public matches with league filter."""
|
||||
response = client.get(f"/api/public/v1/matches?league=Ligue%201")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "meta" in data
|
||||
|
||||
def test_get_public_matches_data_structure(self, sample_match: Match):
|
||||
"""Test that public matches have correct structure."""
|
||||
response = client.get("/api/public/v1/matches")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
match = data["data"][0]
|
||||
|
||||
# Required fields (public only, no sensitive data)
|
||||
assert "id" in match
|
||||
assert "home_team" in match
|
||||
assert "away_team" in match
|
||||
assert "date" in match
|
||||
assert "league" in match
|
||||
assert "status" in match
|
||||
|
||||
def test_get_public_matches_empty_database(self):
|
||||
"""Test retrieval when no matches exist."""
|
||||
response = client.get("/api/public/v1/matches")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert isinstance(data["data"], list)
|
||||
|
||||
|
||||
class TestPublicApiResponseFormat:
|
||||
"""Tests for standardized API response format."""
|
||||
|
||||
def test_response_format_success(self, sample_prediction: Prediction):
|
||||
"""Test that successful responses follow {data, meta} format."""
|
||||
response = client.get("/api/public/v1/predictions")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "data" in data
|
||||
assert "meta" in data
|
||||
assert "timestamp" in data["meta"]
|
||||
assert "version" in data["meta"]
|
||||
|
||||
def test_response_meta_timestamp_format(self, sample_prediction: Prediction):
|
||||
"""Test that timestamp is in ISO 8601 format."""
|
||||
response = client.get("/api/public/v1/predictions")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
timestamp = data["meta"]["timestamp"]
|
||||
|
||||
# ISO 8601 format check (simplified)
|
||||
assert "T" in timestamp
|
||||
assert "Z" in timestamp
|
||||
223
backend/tests/test_rabbitmq_client.py
Normal file
223
backend/tests/test_rabbitmq_client.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
Tests for RabbitMQ client.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
import pika
|
||||
|
||||
from app.queues.rabbitmq_client import RabbitMQClient, create_rabbitmq_client
|
||||
|
||||
|
||||
class TestRabbitMQClient:
|
||||
"""Tests for RabbitMQClient class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test RabbitMQ client initialization."""
|
||||
client = RabbitMQClient(
|
||||
rabbitmq_url="amqp://guest:guest@localhost:5672",
|
||||
prefetch_count=1
|
||||
)
|
||||
|
||||
assert client.rabbitmq_url == "amqp://guest:guest@localhost:5672"
|
||||
assert client.prefetch_count == 1
|
||||
assert client.connection is None
|
||||
assert client.channel is None
|
||||
assert client.queues == {
|
||||
'scraping_tasks': 'scraping_tasks',
|
||||
'sentiment_analysis_tasks': 'sentiment_analysis_tasks',
|
||||
'energy_calculation_tasks': 'energy_calculation_tasks',
|
||||
'results': 'results'
|
||||
}
|
||||
|
||||
@patch('pika.BlockingConnection')
|
||||
def test_connect_success(self, mock_connection):
|
||||
"""Test successful connection to RabbitMQ."""
|
||||
# Setup mocks
|
||||
mock_conn_instance = Mock()
|
||||
mock_channel_instance = Mock()
|
||||
mock_connection.return_value = mock_conn_instance
|
||||
mock_conn_instance.channel.return_value = mock_channel_instance
|
||||
|
||||
# Create client and connect
|
||||
client = RabbitMQClient()
|
||||
client.connect()
|
||||
|
||||
# Verify connection and channel created
|
||||
mock_connection.assert_called_once()
|
||||
mock_conn_instance.channel.assert_called_once()
|
||||
mock_channel_instance.basic_qos.assert_called_once_with(prefetch_count=1)
|
||||
|
||||
# Verify queues declared
|
||||
assert mock_channel_instance.queue_declare.call_count == 4
|
||||
|
||||
# Verify client state
|
||||
assert client.connection == mock_conn_instance
|
||||
assert client.channel == mock_channel_instance
|
||||
|
||||
@patch('pika.BlockingConnection')
|
||||
def test_publish_message(self, mock_connection):
|
||||
"""Test publishing a message to a queue."""
|
||||
# Setup mocks
|
||||
mock_conn_instance = Mock()
|
||||
mock_channel_instance = Mock()
|
||||
mock_connection.return_value = mock_conn_instance
|
||||
mock_conn_instance.channel.return_value = mock_channel_instance
|
||||
|
||||
# Create client and connect
|
||||
client = RabbitMQClient()
|
||||
client.connect()
|
||||
|
||||
# Publish message
|
||||
client.publish_message(
|
||||
queue_name='scraping_tasks',
|
||||
data={'match_id': 123, 'source': 'twitter'},
|
||||
event_type='scraping.task.created'
|
||||
)
|
||||
|
||||
# Verify publish called
|
||||
mock_channel_instance.basic_publish.assert_called_once()
|
||||
call_args = mock_channel_instance.basic_publish.call_args
|
||||
|
||||
# Check routing key
|
||||
assert call_args[1]['routing_key'] == 'scraping_tasks'
|
||||
|
||||
# Check message body is valid JSON
|
||||
message_body = call_args[1]['body']
|
||||
message = json.loads(message_body)
|
||||
|
||||
assert message['event'] == 'scraping.task.created'
|
||||
assert message['version'] == '1.0'
|
||||
assert message['data']['match_id'] == 123
|
||||
assert message['data']['source'] == 'twitter'
|
||||
assert 'timestamp' in message
|
||||
|
||||
# Check message properties
|
||||
properties = call_args[1]['properties']
|
||||
assert properties.delivery_mode == 2 # Persistent
|
||||
|
||||
@patch('pika.BlockingConnection')
|
||||
def test_consume_messages(self, mock_connection):
|
||||
"""Test consuming messages from a queue."""
|
||||
# Setup mocks
|
||||
mock_conn_instance = Mock()
|
||||
mock_channel_instance = Mock()
|
||||
mock_connection.return_value = mock_conn_instance
|
||||
mock_conn_instance.channel.return_value = mock_channel_instance
|
||||
|
||||
# Create client and connect
|
||||
client = RabbitMQClient()
|
||||
client.connect()
|
||||
|
||||
# Mock start_consuming to avoid blocking
|
||||
mock_channel_instance.start_consuming = Mock(side_effect=KeyboardInterrupt())
|
||||
|
||||
# Define callback
|
||||
callback = Mock()
|
||||
|
||||
# Consume messages
|
||||
try:
|
||||
client.consume_messages(queue_name='scraping_tasks', callback=callback)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# Verify consume called
|
||||
mock_channel_instance.basic_consume.assert_called_once()
|
||||
call_args = mock_channel_instance.basic_consume.call_args
|
||||
|
||||
assert call_args[1]['queue'] == 'scraping_tasks'
|
||||
assert call_args[1]['auto_ack'] is False
|
||||
assert call_args[1]['on_message_callback'] == callback
|
||||
|
||||
# Verify start_consuming called
|
||||
mock_channel_instance.start_consuming.assert_called_once()
|
||||
|
||||
@patch('pika.BlockingConnection')
|
||||
def test_ack_message(self, mock_connection):
|
||||
"""Test acknowledging a message."""
|
||||
# Setup mocks
|
||||
mock_conn_instance = Mock()
|
||||
mock_channel_instance = Mock()
|
||||
mock_connection.return_value = mock_conn_instance
|
||||
mock_conn_instance.channel.return_value = mock_channel_instance
|
||||
|
||||
# Create client and connect
|
||||
client = RabbitMQClient()
|
||||
client.connect()
|
||||
|
||||
# Acknowledge message
|
||||
client.ack_message(delivery_tag=123)
|
||||
|
||||
# Verify ack called
|
||||
mock_channel_instance.basic_ack.assert_called_once_with(delivery_tag=123)
|
||||
|
||||
@patch('pika.BlockingConnection')
|
||||
def test_reject_message(self, mock_connection):
|
||||
"""Test rejecting a message."""
|
||||
# Setup mocks
|
||||
mock_conn_instance = Mock()
|
||||
mock_channel_instance = Mock()
|
||||
mock_connection.return_value = mock_conn_instance
|
||||
mock_conn_instance.channel.return_value = mock_channel_instance
|
||||
|
||||
# Create client and connect
|
||||
client = RabbitMQClient()
|
||||
client.connect()
|
||||
|
||||
# Reject message without requeue
|
||||
client.reject_message(delivery_tag=123, requeue=False)
|
||||
|
||||
# Verify reject called
|
||||
mock_channel_instance.basic_reject.assert_called_once_with(
|
||||
delivery_tag=123,
|
||||
requeue=False
|
||||
)
|
||||
|
||||
@patch('pika.BlockingConnection')
|
||||
def test_close_connection(self, mock_connection):
|
||||
"""Test closing connection to RabbitMQ."""
|
||||
# Setup mocks
|
||||
mock_conn_instance = Mock()
|
||||
mock_channel_instance = Mock()
|
||||
mock_connection.return_value = mock_conn_instance
|
||||
mock_conn_instance.channel.return_value = mock_channel_instance
|
||||
|
||||
# Create client and connect
|
||||
client = RabbitMQClient()
|
||||
client.connect()
|
||||
|
||||
# Close connection
|
||||
client.close()
|
||||
|
||||
# Verify channel and connection closed
|
||||
mock_channel_instance.close.assert_called_once()
|
||||
mock_conn_instance.close.assert_called_once()
|
||||
|
||||
|
||||
class TestCreateRabbitMQClient:
|
||||
"""Tests for create_rabbitmq_client factory function."""
|
||||
|
||||
def test_create_with_defaults(self):
|
||||
"""Test creating client with default parameters."""
|
||||
client = create_rabbitmq_client()
|
||||
|
||||
assert client.rabbitmq_url == "amqp://guest:guest@localhost:5672"
|
||||
assert client.prefetch_count == 1
|
||||
assert isinstance(client, RabbitMQClient)
|
||||
|
||||
def test_create_with_custom_url(self):
|
||||
"""Test creating client with custom URL."""
|
||||
client = create_rabbitmq_client(
|
||||
rabbitmq_url="amqp://user:pass@remote:5672"
|
||||
)
|
||||
|
||||
assert client.rabbitmq_url == "amqp://user:pass@remote:5672"
|
||||
assert client.prefetch_count == 1
|
||||
|
||||
def test_create_with_custom_prefetch(self):
|
||||
"""Test creating client with custom prefetch count."""
|
||||
client = create_rabbitmq_client(prefetch_count=5)
|
||||
|
||||
assert client.rabbitmq_url == "amqp://guest:guest@localhost:5672"
|
||||
assert client.prefetch_count == 5
|
||||
146
backend/tests/test_rabbitmq_consumers.py
Normal file
146
backend/tests/test_rabbitmq_consumers.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Tests for RabbitMQ message consumers.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.queues.consumers import (
|
||||
consume_scraping_tasks,
|
||||
consume_sentiment_analysis_tasks,
|
||||
consume_energy_calculation_tasks,
|
||||
consume_results
|
||||
)
|
||||
|
||||
|
||||
class TestConsumeScrapingTasks:
|
||||
"""Tests for consume_scraping_tasks function."""
|
||||
|
||||
@patch('app.queues.consumers.client')
|
||||
@patch('app.queues.consumers.consume_messages')
|
||||
def test_consume_scraping_tasks_success(self, mock_consume, mock_client):
|
||||
"""Test successful scraping task consumption."""
|
||||
# Setup mocks
|
||||
mock_db_session = Mock(spec=Session)
|
||||
mock_db_factory = Mock(return_value=mock_db_session)
|
||||
mock_callback = Mock(return_value={'collected_count': 50})
|
||||
|
||||
# Call consume
|
||||
consume_scraping_tasks(
|
||||
client=mock_client,
|
||||
callback=mock_callback,
|
||||
db_session_factory=mock_db_factory
|
||||
)
|
||||
|
||||
# Verify consume_messages called
|
||||
mock_consume.assert_called_once()
|
||||
call_args = mock_consume.call_args
|
||||
|
||||
assert call_args[0][0] == 'scraping_tasks'
|
||||
assert call_args[1]['callback'] is not None
|
||||
|
||||
@patch('app.queues.consumers.client')
|
||||
@patch('app.queues.consumers.consume_messages')
|
||||
def test_consume_scraping_tasks_error_handling(self, mock_consume, mock_client):
|
||||
"""Test error handling in scraping task consumption."""
|
||||
mock_db_session = Mock(spec=Session)
|
||||
mock_db_factory = Mock(return_value=mock_db_session)
|
||||
mock_callback = Mock(side_effect=Exception("Processing error"))
|
||||
|
||||
# Get the callback function
|
||||
def capture_callback(ch, method, properties, body):
|
||||
message = json.loads(body)
|
||||
mock_callback(message.get('data', {}), mock_db_session)
|
||||
|
||||
# Call consume
|
||||
consume_scraping_tasks(
|
||||
client=mock_client,
|
||||
callback=mock_callback,
|
||||
db_session_factory=mock_db_factory
|
||||
)
|
||||
|
||||
# Verify consume_messages called
|
||||
mock_consume.assert_called_once()
|
||||
|
||||
|
||||
class TestConsumeSentimentAnalysisTasks:
|
||||
"""Tests for consume_sentiment_analysis_tasks function."""
|
||||
|
||||
@patch('app.queues.consumers.client')
|
||||
@patch('app.queues.consumers.consume_messages')
|
||||
def test_consume_sentiment_analysis_tasks_success(self, mock_consume, mock_client):
|
||||
"""Test successful sentiment analysis task consumption."""
|
||||
# Setup mocks
|
||||
mock_db_session = Mock(spec=Session)
|
||||
mock_db_factory = Mock(return_value=mock_db_session)
|
||||
mock_callback = Mock(return_value={'analyzed_count': 100})
|
||||
|
||||
# Call consume
|
||||
consume_sentiment_analysis_tasks(
|
||||
client=mock_client,
|
||||
callback=mock_callback,
|
||||
db_session_factory=mock_db_factory
|
||||
)
|
||||
|
||||
# Verify consume_messages called
|
||||
mock_consume.assert_called_once()
|
||||
call_args = mock_consume.call_args
|
||||
|
||||
assert call_args[0][0] == 'sentiment_analysis_tasks'
|
||||
assert call_args[1]['callback'] is not None
|
||||
|
||||
|
||||
class TestConsumeEnergyCalculationTasks:
|
||||
"""Tests for consume_energy_calculation_tasks function."""
|
||||
|
||||
@patch('app.queues.consumers.client')
|
||||
@patch('app.queues.consumers.consume_messages')
|
||||
def test_consume_energy_calculation_tasks_success(self, mock_consume, mock_client):
|
||||
"""Test successful energy calculation task consumption."""
|
||||
# Setup mocks
|
||||
mock_db_session = Mock(spec=Session)
|
||||
mock_db_factory = Mock(return_value=mock_db_session)
|
||||
mock_callback = Mock(return_value={
|
||||
'energy_score': 75.5,
|
||||
'confidence': 0.82
|
||||
})
|
||||
|
||||
# Call consume
|
||||
consume_energy_calculation_tasks(
|
||||
client=mock_client,
|
||||
callback=mock_callback,
|
||||
db_session_factory=mock_db_factory
|
||||
)
|
||||
|
||||
# Verify consume_messages called
|
||||
mock_consume.assert_called_once()
|
||||
call_args = mock_consume.call_args
|
||||
|
||||
assert call_args[0][0] == 'energy_calculation_tasks'
|
||||
assert call_args[1]['callback'] is not None
|
||||
|
||||
|
||||
class TestConsumeResults:
|
||||
"""Tests for consume_results function."""
|
||||
|
||||
@patch('app.queues.consumers.client')
|
||||
@patch('app.queues.consumers.consume_messages')
|
||||
def test_consume_results_success(self, mock_consume, mock_client):
|
||||
"""Test successful result consumption."""
|
||||
# Setup mocks
|
||||
mock_callback = Mock()
|
||||
|
||||
# Call consume
|
||||
consume_results(
|
||||
client=mock_client,
|
||||
callback=mock_callback
|
||||
)
|
||||
|
||||
# Verify consume_messages called
|
||||
mock_consume.assert_called_once()
|
||||
call_args = mock_consume.call_args
|
||||
|
||||
assert call_args[0][0] == 'results'
|
||||
assert call_args[1]['callback'] is not None
|
||||
266
backend/tests/test_rabbitmq_producers.py
Normal file
266
backend/tests/test_rabbitmq_producers.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
Tests for RabbitMQ message producers.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from datetime import datetime
|
||||
|
||||
from app.queues.producers import (
|
||||
publish_scraping_task,
|
||||
publish_sentiment_analysis_task,
|
||||
publish_energy_calculation_task,
|
||||
publish_result,
|
||||
publish_scraping_result,
|
||||
publish_sentiment_analysis_result,
|
||||
publish_energy_calculation_result
|
||||
)
|
||||
|
||||
|
||||
class TestPublishScrapingTask:
|
||||
"""Tests for publish_scraping_task function."""
|
||||
|
||||
@patch('app.queues.producers.datetime')
|
||||
def test_publish_scraping_task_twitter(self, mock_datetime):
|
||||
"""Test publishing a Twitter scraping task."""
|
||||
# Mock datetime
|
||||
mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"
|
||||
|
||||
# Create mock client
|
||||
mock_client = Mock()
|
||||
mock_client.publish_message = Mock()
|
||||
|
||||
# Publish task
|
||||
publish_scraping_task(
|
||||
client=mock_client,
|
||||
match_id=123,
|
||||
source='twitter',
|
||||
keywords=['#MatchName', 'team1 vs team2'],
|
||||
priority='normal'
|
||||
)
|
||||
|
||||
# Verify publish_message called
|
||||
mock_client.publish_message.assert_called_once()
|
||||
call_args = mock_client.publish_message.call_args
|
||||
|
||||
# Check parameters
|
||||
assert call_args[1]['queue_name'] == 'scraping_tasks'
|
||||
assert call_args[1]['event_type'] == 'scraping.task.created'
|
||||
|
||||
# Check task data
|
||||
task_data = call_args[1]['data']
|
||||
assert task_data['task_type'] == 'scraping'
|
||||
assert task_data['match_id'] == 123
|
||||
assert task_data['source'] == 'twitter'
|
||||
assert task_data['keywords'] == ['#MatchName', 'team1 vs team2']
|
||||
assert task_data['priority'] == 'normal'
|
||||
assert 'created_at' in task_data
|
||||
|
||||
@patch('app.queues.producers.datetime')
|
||||
def test_publish_scraping_task_vip(self, mock_datetime):
|
||||
"""Test publishing a VIP scraping task."""
|
||||
mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.publish_message = Mock()
|
||||
|
||||
publish_scraping_task(
|
||||
client=mock_client,
|
||||
match_id=456,
|
||||
source='reddit',
|
||||
keywords=['Ligue1'],
|
||||
priority='vip'
|
||||
)
|
||||
|
||||
call_args = mock_client.publish_message.call_args
|
||||
task_data = call_args[1]['data']
|
||||
|
||||
assert task_data['priority'] == 'vip'
|
||||
assert task_data['source'] == 'reddit'
|
||||
|
||||
|
||||
class TestPublishSentimentAnalysisTask:
|
||||
"""Tests for publish_sentiment_analysis_task function."""
|
||||
|
||||
@patch('app.queues.producers.datetime')
|
||||
def test_publish_sentiment_analysis_task_twitter(self, mock_datetime):
|
||||
"""Test publishing a Twitter sentiment analysis task."""
|
||||
mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.publish_message = Mock()
|
||||
|
||||
publish_sentiment_analysis_task(
|
||||
client=mock_client,
|
||||
match_id=123,
|
||||
source='twitter',
|
||||
entity_ids=['tweet1', 'tweet2', 'tweet3']
|
||||
)
|
||||
|
||||
call_args = mock_client.publish_message.call_args
|
||||
assert call_args[1]['queue_name'] == 'sentiment_analysis_tasks'
|
||||
assert call_args[1]['event_type'] == 'sentiment_analysis.task.created'
|
||||
|
||||
task_data = call_args[1]['data']
|
||||
assert task_data['task_type'] == 'sentiment_analysis'
|
||||
assert task_data['match_id'] == 123
|
||||
assert task_data['source'] == 'twitter'
|
||||
assert task_data['entity_ids'] == ['tweet1', 'tweet2', 'tweet3']
|
||||
assert task_data['texts'] == []
|
||||
assert 'created_at' in task_data
|
||||
|
||||
@patch('app.queues.producers.datetime')
|
||||
def test_publish_sentiment_analysis_task_with_texts(self, mock_datetime):
|
||||
"""Test publishing sentiment analysis task with texts."""
|
||||
mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.publish_message = Mock()
|
||||
|
||||
texts = ["Great match!", "Amazing goal", "What a game"]
|
||||
|
||||
publish_sentiment_analysis_task(
|
||||
client=mock_client,
|
||||
match_id=123,
|
||||
source='reddit',
|
||||
entity_ids=['post1', 'post2', 'post3'],
|
||||
texts=texts
|
||||
)
|
||||
|
||||
call_args = mock_client.publish_message.call_args
|
||||
task_data = call_args[1]['data']
|
||||
|
||||
assert task_data['texts'] == texts
|
||||
assert task_data['source'] == 'reddit'
|
||||
|
||||
|
||||
class TestPublishEnergyCalculationTask:
|
||||
"""Tests for publish_energy_calculation_task function."""
|
||||
|
||||
@patch('app.queues.producers.datetime')
|
||||
def test_publish_energy_calculation_task(self, mock_datetime):
|
||||
"""Test publishing an energy calculation task."""
|
||||
mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.publish_message = Mock()
|
||||
|
||||
twitter_sentiments = [
|
||||
{'compound': 0.5, 'sentiment': 'positive'},
|
||||
{'compound': -0.3, 'sentiment': 'negative'}
|
||||
]
|
||||
|
||||
publish_energy_calculation_task(
|
||||
client=mock_client,
|
||||
match_id=123,
|
||||
team_id=456,
|
||||
twitter_sentiments=twitter_sentiments
|
||||
)
|
||||
|
||||
call_args = mock_client.publish_message.call_args
|
||||
assert call_args[1]['queue_name'] == 'energy_calculation_tasks'
|
||||
assert call_args[1]['event_type'] == 'energy_calculation.task.created'
|
||||
|
||||
task_data = call_args[1]['data']
|
||||
assert task_data['task_type'] == 'energy_calculation'
|
||||
assert task_data['match_id'] == 123
|
||||
assert task_data['team_id'] == 456
|
||||
assert task_data['twitter_sentiments'] == twitter_sentiments
|
||||
assert task_data['reddit_sentiments'] == []
|
||||
assert task_data['rss_sentiments'] == []
|
||||
assert task_data['tweets_with_timestamps'] == []
|
||||
assert 'created_at' in task_data
|
||||
|
||||
|
||||
class TestPublishResult:
|
||||
"""Tests for publish_result function."""
|
||||
|
||||
@patch('app.queues.producers.datetime')
|
||||
def test_publish_scraping_result_wrapper(self, mock_datetime):
|
||||
"""Test publishing a scraping result."""
|
||||
mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.publish_message = Mock()
|
||||
|
||||
publish_scraping_result(
|
||||
client=mock_client,
|
||||
match_id=123,
|
||||
source='twitter',
|
||||
collected_count=100,
|
||||
metadata={'keywords': ['#MatchName']}
|
||||
)
|
||||
|
||||
call_args = mock_client.publish_message.call_args
|
||||
assert call_args[1]['queue_name'] == 'results'
|
||||
assert call_args[1]['event_type'] == 'result.published'
|
||||
|
||||
result_data = call_args[1]['data']
|
||||
assert result_data['result_type'] == 'scraping'
|
||||
assert result_data['data']['match_id'] == 123
|
||||
assert result_data['data']['source'] == 'twitter'
|
||||
assert result_data['data']['collected_count'] == 100
|
||||
assert result_data['data']['status'] == 'success'
|
||||
assert result_data['data']['metadata'] == {'keywords': ['#MatchName']}
|
||||
|
||||
@patch('app.queues.producers.datetime')
|
||||
def test_publish_sentiment_analysis_result_wrapper(self, mock_datetime):
|
||||
"""Test publishing a sentiment analysis result."""
|
||||
mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.publish_message = Mock()
|
||||
|
||||
metrics = {
|
||||
'total_count': 100,
|
||||
'positive_count': 60,
|
||||
'negative_count': 20,
|
||||
'neutral_count': 20,
|
||||
'average_compound': 0.35
|
||||
}
|
||||
|
||||
publish_sentiment_analysis_result(
|
||||
client=mock_client,
|
||||
match_id=123,
|
||||
source='twitter',
|
||||
analyzed_count=100,
|
||||
metrics=metrics
|
||||
)
|
||||
|
||||
call_args = mock_client.publish_message.call_args
|
||||
result_data = call_args[1]['data']
|
||||
|
||||
assert result_data['result_type'] == 'sentiment'
|
||||
assert result_data['data']['match_id'] == 123
|
||||
assert result_data['data']['source'] == 'twitter'
|
||||
assert result_data['data']['analyzed_count'] == 100
|
||||
assert result_data['data']['metrics'] == metrics
|
||||
assert result_data['data']['status'] == 'success'
|
||||
|
||||
@patch('app.queues.producers.datetime')
|
||||
def test_publish_energy_calculation_result_wrapper(self, mock_datetime):
|
||||
"""Test publishing an energy calculation result."""
|
||||
mock_datetime.utcnow.return_value.isoformat.return_value = "2026-01-17T10:00:00"
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.publish_message = Mock()
|
||||
|
||||
publish_energy_calculation_result(
|
||||
client=mock_client,
|
||||
match_id=123,
|
||||
team_id=456,
|
||||
energy_score=78.5,
|
||||
confidence=0.82,
|
||||
sources_used=['twitter', 'reddit']
|
||||
)
|
||||
|
||||
call_args = mock_client.publish_message.call_args
|
||||
result_data = call_args[1]['data']
|
||||
|
||||
assert result_data['result_type'] == 'energy'
|
||||
assert result_data['data']['match_id'] == 123
|
||||
assert result_data['data']['team_id'] == 456
|
||||
assert result_data['data']['energy_score'] == 78.5
|
||||
assert result_data['data']['confidence'] == 0.82
|
||||
assert result_data['data']['sources_used'] == ['twitter', 'reddit']
|
||||
assert result_data['data']['status'] == 'success'
|
||||
381
backend/tests/test_reddit_scraper.py
Normal file
381
backend/tests/test_reddit_scraper.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
Unit tests for Reddit scraper.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
import praw.exceptions
|
||||
|
||||
from app.scrapers.reddit_scraper import (
|
||||
RedditScraper,
|
||||
RedditPostData,
|
||||
RedditCommentData,
|
||||
create_reddit_scraper
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_praw():
|
||||
"""Mock praw.Reddit."""
|
||||
with patch('app.scrapers.reddit_scraper.praw.Reddit') as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client_id():
|
||||
"""Test Reddit client ID."""
|
||||
return "test_client_id"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client_secret():
|
||||
"""Test Reddit client secret."""
|
||||
return "test_client_secret"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_scraper(test_client_id, test_client_secret):
|
||||
"""Create test scraper instance."""
|
||||
with patch('app.scrapers.reddit_scraper.praw.Reddit'):
|
||||
scraper = RedditScraper(
|
||||
client_id=test_client_id,
|
||||
client_secret=test_client_secret,
|
||||
subreddits=["soccer", "football"],
|
||||
max_posts_per_subreddit=100,
|
||||
max_comments_per_post=50
|
||||
)
|
||||
return scraper
|
||||
|
||||
|
||||
class TestRedditPostData:
|
||||
"""Test RedditPostData dataclass."""
|
||||
|
||||
def test_reddit_post_data_creation(self):
|
||||
"""Test creating RedditPostData instance."""
|
||||
post = RedditPostData(
|
||||
post_id="abc123",
|
||||
title="Test post title",
|
||||
text="Test post content",
|
||||
upvotes=100,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
match_id=1,
|
||||
subreddit="soccer"
|
||||
)
|
||||
|
||||
assert post.post_id == "abc123"
|
||||
assert post.title == "Test post title"
|
||||
assert post.text == "Test post content"
|
||||
assert post.upvotes == 100
|
||||
assert post.match_id == 1
|
||||
assert post.subreddit == "soccer"
|
||||
assert post.source == "reddit"
|
||||
|
||||
|
||||
class TestRedditCommentData:
|
||||
"""Test RedditCommentData dataclass."""
|
||||
|
||||
def test_reddit_comment_data_creation(self):
|
||||
"""Test creating RedditCommentData instance."""
|
||||
comment = RedditCommentData(
|
||||
comment_id="def456",
|
||||
post_id="abc123",
|
||||
text="Test comment content",
|
||||
upvotes=50,
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
assert comment.comment_id == "def456"
|
||||
assert comment.post_id == "abc123"
|
||||
assert comment.text == "Test comment content"
|
||||
assert comment.upvotes == 50
|
||||
assert comment.source == "reddit"
|
||||
|
||||
|
||||
class TestRedditScraper:
|
||||
"""Test RedditScraper class."""
|
||||
|
||||
def test_scraper_initialization(self, test_client_id, test_client_secret):
|
||||
"""Test scraper initialization."""
|
||||
with patch('app.scrapers.reddit_scraper.praw.Reddit'):
|
||||
scraper = RedditScraper(
|
||||
client_id=test_client_id,
|
||||
client_secret=test_client_secret,
|
||||
subreddits=["soccer", "football"],
|
||||
max_posts_per_subreddit=100,
|
||||
max_comments_per_post=50
|
||||
)
|
||||
|
||||
assert scraper.client_id == test_client_id
|
||||
assert scraper.client_secret == test_client_secret
|
||||
assert scraper.subreddits == ["soccer", "football"]
|
||||
assert scraper.max_posts_per_subreddit == 100
|
||||
assert scraper.max_comments_per_post == 50
|
||||
|
||||
def test_verify_authentication_success(self, mock_praw, caplog):
|
||||
"""Test successful authentication verification."""
|
||||
mock_reddit = Mock()
|
||||
mock_reddit.user.me.return_value = Mock(name="test_user")
|
||||
mock_praw.return_value = mock_reddit
|
||||
|
||||
with caplog.at_level("INFO"):
|
||||
scraper = RedditScraper(
|
||||
client_id="test_id",
|
||||
client_secret="test_secret",
|
||||
subreddits=["soccer"]
|
||||
)
|
||||
|
||||
assert "✅ Reddit API authenticated successfully" in caplog.text
|
||||
|
||||
def test_verify_authentication_failure(self, mock_praw, caplog):
|
||||
"""Test failed authentication verification."""
|
||||
mock_reddit = Mock()
|
||||
mock_reddit.user.me.side_effect = Exception("Auth failed")
|
||||
mock_praw.return_value = mock_reddit
|
||||
|
||||
with pytest.raises(Exception, match="Reddit API authentication failed"):
|
||||
RedditScraper(
|
||||
client_id="invalid_id",
|
||||
client_secret="invalid_secret",
|
||||
subreddits=["soccer"]
|
||||
)
|
||||
|
||||
def test_scrape_posts_empty(self, test_scraper, mock_praw, caplog):
|
||||
"""Test scraping posts with no results."""
|
||||
mock_subreddit = Mock()
|
||||
mock_subreddit.new.return_value = []
|
||||
mock_reddit = Mock()
|
||||
mock_reddit.subreddit.return_value = mock_subreddit
|
||||
mock_praw.return_value = mock_reddit
|
||||
|
||||
with caplog.at_level("INFO"):
|
||||
result = test_scraper.scrape_posts(
|
||||
subreddit="soccer",
|
||||
match_id=1,
|
||||
keywords=["test"]
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert "ℹ️ No posts found" in caplog.text
|
||||
|
||||
def test_scrape_posts_success(self, test_scraper, mock_praw):
|
||||
"""Test successful post scraping."""
|
||||
# Mock post
|
||||
mock_post = Mock()
|
||||
mock_post.id = "abc123"
|
||||
mock_post.title = "Test match discussion"
|
||||
mock_post.selftext = "Great match today!"
|
||||
mock_post.score = 100
|
||||
mock_post.created_utc = 1700000000.0
|
||||
|
||||
mock_subreddit = Mock()
|
||||
mock_subreddit.new.return_value = [mock_post]
|
||||
mock_reddit = Mock()
|
||||
mock_reddit.subreddit.return_value = mock_subreddit
|
||||
mock_praw.return_value = mock_reddit
|
||||
|
||||
result = test_scraper.scrape_posts(
|
||||
subreddit="soccer",
|
||||
match_id=1,
|
||||
keywords=["test"]
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].post_id == "abc123"
|
||||
assert result[0].title == "Test match discussion"
|
||||
assert result[0].upvotes == 100
|
||||
|
||||
def test_scrape_comments_success(self, test_scraper, mock_praw):
|
||||
"""Test successful comment scraping."""
|
||||
# Mock comment
|
||||
mock_comment = Mock()
|
||||
mock_comment.id = "def456"
|
||||
mock_comment.body = "Great goal!"
|
||||
mock_comment.score = 50
|
||||
mock_comment.created_utc = 1700000000.0
|
||||
|
||||
mock_post = Mock()
|
||||
mock_post.comments.list.return_value = [mock_comment]
|
||||
|
||||
result = test_scraper.scrape_comments(post_id="abc123", post=mock_post)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].comment_id == "def456"
|
||||
assert result[0].text == "Great goal!"
|
||||
assert result[0].upvotes == 50
|
||||
|
||||
def test_save_posts_to_db(self, test_scraper, mock_praw):
|
||||
"""Test saving posts to database."""
|
||||
from app.models.reddit_post import RedditPost
|
||||
|
||||
# Mock posts data
|
||||
posts_data = [
|
||||
RedditPostData(
|
||||
post_id="abc123",
|
||||
title="Test post",
|
||||
text="Test content",
|
||||
upvotes=100,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
match_id=1,
|
||||
subreddit="soccer"
|
||||
)
|
||||
]
|
||||
|
||||
# Mock database session
|
||||
mock_db = Mock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Mock model
|
||||
mock_post_instance = Mock()
|
||||
mock_post_instance.id = 1
|
||||
|
||||
with patch('app.scrapers.reddit_scraper.RedditPost', return_value=mock_post_instance):
|
||||
test_scraper.save_posts_to_db(posts_data, mock_db)
|
||||
|
||||
mock_db.add.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_save_comments_to_db(self, test_scraper):
|
||||
"""Test saving comments to database."""
|
||||
# Mock comments data
|
||||
comments_data = [
|
||||
RedditCommentData(
|
||||
comment_id="def456",
|
||||
post_id="abc123",
|
||||
text="Test comment",
|
||||
upvotes=50,
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
]
|
||||
|
||||
# Mock database session
|
||||
mock_db = Mock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Mock model
|
||||
mock_comment_instance = Mock()
|
||||
mock_comment_instance.id = 1
|
||||
|
||||
with patch('app.scrapers.reddit_scraper.RedditComment', return_value=mock_comment_instance):
|
||||
test_scraper.save_comments_to_db(comments_data, mock_db)
|
||||
|
||||
mock_db.add.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_save_posts_to_db_existing(self, test_scraper):
|
||||
"""Test saving posts when post already exists."""
|
||||
posts_data = [
|
||||
RedditPostData(
|
||||
post_id="abc123",
|
||||
title="Test post",
|
||||
text="Test content",
|
||||
upvotes=100,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
match_id=1,
|
||||
subreddit="soccer"
|
||||
)
|
||||
]
|
||||
|
||||
mock_db = Mock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = Mock()
|
||||
|
||||
test_scraper.save_posts_to_db(posts_data, mock_db)
|
||||
|
||||
# Should not add or commit if post already exists
|
||||
mock_db.add.assert_not_called()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_scrape_posts_api_error_continues(self, test_scraper, mock_praw, caplog):
|
||||
"""Test that scraper continues on API error."""
|
||||
mock_subreddit = Mock()
|
||||
mock_subreddit.new.side_effect = praw.exceptions.PRAWException("API Error")
|
||||
mock_reddit = Mock()
|
||||
mock_reddit.subreddit.return_value = mock_subreddit
|
||||
mock_praw.return_value = mock_reddit
|
||||
|
||||
with caplog.at_level("ERROR"):
|
||||
result = test_scraper.scrape_posts(
|
||||
subreddit="soccer",
|
||||
match_id=1,
|
||||
keywords=["test"]
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert "Reddit API error" in caplog.text
|
||||
|
||||
def test_scrape_comments_api_error_continues(self, test_scraper, caplog):
|
||||
"""Test that scraper continues on comment API error."""
|
||||
mock_post = Mock()
|
||||
mock_post.comments.list.side_effect = praw.exceptions.PRAWException("API Error")
|
||||
|
||||
with caplog.at_level("ERROR"):
|
||||
result = test_scraper.scrape_comments(post_id="abc123", post=mock_post)
|
||||
|
||||
assert result == []
|
||||
assert "Reddit API error" in caplog.text
|
||||
|
||||
def test_save_posts_db_error_rollback(self, test_scraper, caplog):
|
||||
"""Test that database errors trigger rollback."""
|
||||
posts_data = [
|
||||
RedditPostData(
|
||||
post_id="abc123",
|
||||
title="Test post",
|
||||
text="Test content",
|
||||
upvotes=100,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
match_id=1,
|
||||
subreddit="soccer"
|
||||
)
|
||||
]
|
||||
|
||||
mock_db = Mock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
mock_db.commit.side_effect = Exception("Database error")
|
||||
|
||||
with caplog.at_level("ERROR"):
|
||||
with pytest.raises(Exception):
|
||||
test_scraper.save_posts_to_db(posts_data, mock_db)
|
||||
|
||||
mock_db.rollback.assert_called_once()
|
||||
assert "Failed to save Reddit posts" in caplog.text
|
||||
|
||||
def test_scrape_reddit_match_continues_on_subreddit_error(
|
||||
self, test_scraper, mock_praw, caplog
|
||||
):
|
||||
"""Test that scraper continues with other subreddits on error."""
|
||||
# Mock first subreddit to error, second to succeed
|
||||
mock_subreddit1 = Mock()
|
||||
mock_subreddit1.new.side_effect = praw.exceptions.PRAWException("API Error")
|
||||
mock_subreddit2 = Mock()
|
||||
mock_subreddit2.new.return_value = []
|
||||
|
||||
mock_reddit = Mock()
|
||||
mock_reddit.subreddit.side_effect = [mock_subreddit1, mock_subreddit2]
|
||||
mock_praw.return_value = mock_reddit
|
||||
|
||||
test_scraper.subreddits = ["soccer", "football"]
|
||||
|
||||
with caplog.at_level("ERROR"):
|
||||
result = test_scraper.scrape_reddit_match(match_id=1)
|
||||
|
||||
assert result['posts'] == []
|
||||
assert result['comments'] == []
|
||||
assert "Continuing with other sources" in caplog.text
|
||||
|
||||
|
||||
class TestCreateRedditScraper:
|
||||
"""Test create_reddit_scraper factory function."""
|
||||
|
||||
def test_factory_function(self, test_client_id, test_client_secret):
|
||||
"""Test factory function creates scraper."""
|
||||
with patch('app.scrapers.reddit_scraper.RedditScraper') as MockScraper:
|
||||
mock_instance = Mock()
|
||||
MockScraper.return_value = mock_instance
|
||||
|
||||
result = create_reddit_scraper(
|
||||
client_id=test_client_id,
|
||||
client_secret=test_client_secret
|
||||
)
|
||||
|
||||
MockScraper.assert_called_once()
|
||||
assert result == mock_instance
|
||||
369
backend/tests/test_rss_scraper.py
Normal file
369
backend/tests/test_rss_scraper.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""
|
||||
Tests for RSS scraper module.
|
||||
|
||||
This test suite validates the RSS scraper functionality including
|
||||
parsing, filtering, error handling, and database operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from app.scrapers.rss_scraper import (
|
||||
RSSScraper,
|
||||
RSSArticleData,
|
||||
create_rss_scraper
|
||||
)
|
||||
|
||||
|
||||
class TestRSSScraperInit:
|
||||
"""Tests for RSS scraper initialization."""
|
||||
|
||||
def test_init_default_sources(self):
|
||||
"""Test initialization with default RSS sources."""
|
||||
scraper = RSSScraper()
|
||||
|
||||
assert len(scraper.rss_sources) == 4
|
||||
assert scraper.timeout == 30
|
||||
assert scraper.max_articles_per_source == 100
|
||||
assert len(scraper.keywords) > 0
|
||||
|
||||
def test_init_custom_sources(self):
|
||||
"""Test initialization with custom RSS sources."""
|
||||
custom_sources = ["http://example.com/rss"]
|
||||
scraper = RSSScraper(rss_sources=custom_sources)
|
||||
|
||||
assert scraper.rss_sources == custom_sources
|
||||
assert len(scraper.rss_sources) == 1
|
||||
|
||||
def test_init_custom_keywords(self):
|
||||
"""Test initialization with custom keywords."""
|
||||
custom_keywords = ["football", "soccer"]
|
||||
scraper = RSSScraper(keywords=custom_keywords)
|
||||
|
||||
assert scraper.keywords == custom_keywords
|
||||
|
||||
def test_init_custom_timeout(self):
|
||||
"""Test initialization with custom timeout."""
|
||||
scraper = RSSScraper(timeout=60)
|
||||
|
||||
assert scraper.timeout == 60
|
||||
|
||||
|
||||
class TestRSSScraperIsArticleRelevant:
|
||||
"""Tests for article relevance filtering."""
|
||||
|
||||
def test_relevant_article_with_keyword(self):
|
||||
"""Test that article with keyword is relevant."""
|
||||
scraper = RSSScraper()
|
||||
|
||||
title = "Arsenal wins Premier League match"
|
||||
content = "Great performance by the team"
|
||||
|
||||
assert scraper._is_article_relevant(title, content) is True
|
||||
|
||||
def test_relevant_article_multiple_keywords(self):
|
||||
"""Test article with multiple keywords."""
|
||||
scraper = RSSScraper()
|
||||
|
||||
title = "Champions League: Real Madrid vs Barcelona"
|
||||
content = "Soccer match preview"
|
||||
|
||||
assert scraper._is_article_relevant(title, content) is True
|
||||
|
||||
def test_irrelevant_article(self):
|
||||
"""Test that irrelevant article is filtered out."""
|
||||
scraper = RSSScraper()
|
||||
|
||||
title = "Technology news: New iPhone released"
|
||||
content = "Apple announced new products"
|
||||
|
||||
assert scraper._is_article_relevant(title, content) is False
|
||||
|
||||
def test_case_insensitive_matching(self):
|
||||
"""Test that keyword matching is case insensitive."""
|
||||
scraper = RSSScraper()
|
||||
|
||||
title = "FOOTBALL MATCH: TEAM A VS TEAM B"
|
||||
content = "SOCCER game details"
|
||||
|
||||
assert scraper._is_article_relevant(title, content) is True
|
||||
|
||||
|
||||
class TestRSSScraperParsePublishedDate:
|
||||
"""Tests for date parsing."""
|
||||
|
||||
def test_parse_valid_date(self):
|
||||
"""Test parsing a valid date string."""
|
||||
scraper = RSSScraper()
|
||||
|
||||
date_str = "Sat, 15 Jan 2026 10:30:00 +0000"
|
||||
parsed = scraper._parse_published_date(date_str)
|
||||
|
||||
assert isinstance(parsed, datetime)
|
||||
assert parsed.tzinfo is not None
|
||||
|
||||
def test_parse_invalid_date(self):
|
||||
"""Test parsing an invalid date string falls back to current time."""
|
||||
scraper = RSSScraper()
|
||||
|
||||
date_str = "invalid-date"
|
||||
parsed = scraper._parse_published_date(date_str)
|
||||
|
||||
assert isinstance(parsed, datetime)
|
||||
assert parsed.tzinfo is not None
|
||||
|
||||
|
||||
class TestRSSScraperParseFeed:
|
||||
"""Tests for RSS feed parsing."""
|
||||
|
||||
@patch('feedparser.parse')
|
||||
def test_parse_valid_feed(self, mock_parse):
|
||||
"""Test parsing a valid RSS feed."""
|
||||
# Mock feedparser response
|
||||
mock_feed = Mock()
|
||||
mock_feed.feed = {'title': 'ESPN'}
|
||||
mock_feed.bozo = False
|
||||
mock_feed.entries = [
|
||||
Mock(
|
||||
id='article-1',
|
||||
title='Football match preview',
|
||||
summary='Team A vs Team B',
|
||||
published='Sat, 15 Jan 2026 10:30:00 +0000',
|
||||
link='http://example.com/article-1'
|
||||
)
|
||||
]
|
||||
mock_parse.return_value = mock_feed
|
||||
|
||||
scraper = RSSScraper()
|
||||
articles = scraper._parse_feed('http://example.com/rss')
|
||||
|
||||
assert len(articles) >= 0
|
||||
mock_parse.assert_called_once()
|
||||
|
||||
@patch('feedparser.parse')
|
||||
def test_parse_feed_with_bozo_error(self, mock_parse):
|
||||
"""Test parsing a feed with XML errors."""
|
||||
# Mock feedparser response with bozo error
|
||||
mock_feed = Mock()
|
||||
mock_feed.feed = {'title': 'ESPN'}
|
||||
mock_feed.bozo = True
|
||||
mock_feed.entries = []
|
||||
mock_parse.return_value = mock_feed
|
||||
|
||||
scraper = RSSScraper()
|
||||
articles = scraper._parse_feed('http://example.com/rss')
|
||||
|
||||
# Should not crash, but log warning
|
||||
assert isinstance(articles, list)
|
||||
|
||||
@patch('feedparser.parse')
|
||||
def test_parse_feed_filters_irrelevant_articles(self, mock_parse):
|
||||
"""Test that irrelevant articles are filtered out."""
|
||||
# Mock feedparser response
|
||||
mock_feed = Mock()
|
||||
mock_feed.feed = {'title': 'ESPN'}
|
||||
mock_feed.bozo = False
|
||||
mock_feed.entries = [
|
||||
Mock(
|
||||
id='article-1',
|
||||
title='Football news',
|
||||
summary='Match result',
|
||||
published='Sat, 15 Jan 2026 10:30:00 +0000',
|
||||
link='http://example.com/article-1'
|
||||
),
|
||||
Mock(
|
||||
id='article-2',
|
||||
title='Technology news',
|
||||
summary='New iPhone',
|
||||
published='Sat, 15 Jan 2026 11:30:00 +0000',
|
||||
link='http://example.com/article-2'
|
||||
)
|
||||
]
|
||||
mock_parse.return_value = mock_feed
|
||||
|
||||
scraper = RSSScraper()
|
||||
articles = scraper._parse_feed('http://example.com/rss')
|
||||
|
||||
# Only football article should be included
|
||||
football_articles = [a for a in articles if 'football' in a.title.lower()]
|
||||
assert len(football_articles) >= 0
|
||||
|
||||
|
||||
class TestRSSScraperScrapeAllSources:
|
||||
"""Tests for scraping all sources."""
|
||||
|
||||
@patch('app.scrapers.rss_scraper.RSSScraper._parse_feed')
|
||||
def test_scrape_all_sources(self, mock_parse_feed):
|
||||
"""Test scraping all configured sources."""
|
||||
# Mock feed parsing
|
||||
mock_parse_feed.return_value = [
|
||||
RSSArticleData(
|
||||
article_id='article-1',
|
||||
title='Football news',
|
||||
content='Match result',
|
||||
published_at=datetime.now(timezone.utc),
|
||||
source_url='http://example.com/rss',
|
||||
match_id=None,
|
||||
source='ESPN'
|
||||
)
|
||||
]
|
||||
|
||||
scraper = RSSScraper()
|
||||
articles = scraper.scrape_all_sources()
|
||||
|
||||
assert len(articles) >= 0
|
||||
assert mock_parse_feed.call_count == len(scraper.rss_sources)
|
||||
|
||||
@patch('app.scrapers.rss_scraper.RSSScraper._parse_feed')
|
||||
def test_scrape_all_sources_with_match_id(self, mock_parse_feed):
|
||||
"""Test scraping with match ID."""
|
||||
mock_parse_feed.return_value = [
|
||||
RSSArticleData(
|
||||
article_id='article-1',
|
||||
title='Football news',
|
||||
content='Match result',
|
||||
published_at=datetime.now(timezone.utc),
|
||||
source_url='http://example.com/rss',
|
||||
match_id=None,
|
||||
source='ESPN'
|
||||
)
|
||||
]
|
||||
|
||||
scraper = RSSScraper()
|
||||
match_id = 123
|
||||
articles = scraper.scrape_all_sources(match_id=match_id)
|
||||
|
||||
# All articles should have match_id set
|
||||
for article in articles:
|
||||
assert article.match_id == match_id
|
||||
|
||||
@patch('app.scrapers.rss_scraper.RSSScraper._parse_feed')
|
||||
def test_scrape_all_sources_continues_on_error(self, mock_parse_feed):
|
||||
"""Test that scraper continues on source errors."""
|
||||
# Make second source raise an error
|
||||
mock_parse_feed.side_effect = [
|
||||
[], # First source succeeds
|
||||
Exception("Network error"), # Second source fails
|
||||
[] # Third source succeeds
|
||||
]
|
||||
|
||||
scraper = RSSScraper()
|
||||
articles = scraper.scrape_all_sources()
|
||||
|
||||
# Should have collected from successful sources
|
||||
assert isinstance(articles, list)
|
||||
|
||||
|
||||
class TestRSSScraperSaveToDatabase:
|
||||
"""Tests for saving articles to database."""
|
||||
|
||||
@patch('app.scrapers.rss_scraper.Session')
|
||||
def test_save_articles_to_db(self, mock_session_class):
|
||||
"""Test saving articles to database."""
|
||||
mock_db = Mock()
|
||||
mock_session_class.return_value = mock_db
|
||||
|
||||
# Mock query to return None (article doesn't exist)
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
scraper = RSSScraper()
|
||||
articles = [
|
||||
RSSArticleData(
|
||||
article_id='article-1',
|
||||
title='Football news',
|
||||
content='Match result',
|
||||
published_at=datetime.now(timezone.utc),
|
||||
source_url='http://example.com/rss',
|
||||
match_id=None,
|
||||
source='ESPN'
|
||||
)
|
||||
]
|
||||
|
||||
scraper.save_articles_to_db(articles, mock_db)
|
||||
|
||||
# Verify commit was called
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
@patch('app.scrapers.rss_scraper.Session')
|
||||
def test_save_articles_skips_duplicates(self, mock_session_class):
|
||||
"""Test that duplicate articles are skipped."""
|
||||
mock_db = Mock()
|
||||
mock_session_class.return_value = mock_db
|
||||
|
||||
# Mock query to return existing article
|
||||
mock_existing = Mock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_existing
|
||||
|
||||
scraper = RSSScraper()
|
||||
articles = [
|
||||
RSSArticleData(
|
||||
article_id='article-1',
|
||||
title='Football news',
|
||||
content='Match result',
|
||||
published_at=datetime.now(timezone.utc),
|
||||
source_url='http://example.com/rss',
|
||||
match_id=None,
|
||||
source='ESPN'
|
||||
)
|
||||
]
|
||||
|
||||
scraper.save_articles_to_db(articles, mock_db)
|
||||
|
||||
# Should not add duplicate, but still commit
|
||||
assert mock_db.add.call_count == 0
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
@patch('app.scrapers.rss_scraper.Session')
|
||||
def test_save_articles_handles_db_error(self, mock_session_class):
|
||||
"""Test that database errors are handled properly."""
|
||||
mock_db = Mock()
|
||||
mock_session_class.return_value = mock_db
|
||||
|
||||
# Mock query and commit to raise error
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
mock_db.commit.side_effect = Exception("Database error")
|
||||
|
||||
scraper = RSSScraper()
|
||||
articles = [
|
||||
RSSArticleData(
|
||||
article_id='article-1',
|
||||
title='Football news',
|
||||
content='Match result',
|
||||
published_at=datetime.now(timezone.utc),
|
||||
source_url='http://example.com/rss',
|
||||
match_id=None,
|
||||
source='ESPN'
|
||||
)
|
||||
]
|
||||
|
||||
# Should raise exception
|
||||
with pytest.raises(Exception):
|
||||
scraper.save_articles_to_db(articles, mock_db)
|
||||
|
||||
# Verify rollback was called
|
||||
mock_db.rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestCreateRSSScraper:
|
||||
"""Tests for factory function."""
|
||||
|
||||
def test_create_default_scraper(self):
|
||||
"""Test creating a scraper with default config."""
|
||||
scraper = create_rss_scraper()
|
||||
|
||||
assert isinstance(scraper, RSSScraper)
|
||||
assert len(scraper.rss_sources) > 0
|
||||
|
||||
def test_create_custom_scraper(self):
|
||||
"""Test creating a scraper with custom config."""
|
||||
custom_sources = ["http://custom.com/rss"]
|
||||
custom_keywords = ["football"]
|
||||
|
||||
scraper = create_rss_scraper(
|
||||
rss_sources=custom_sources,
|
||||
keywords=custom_keywords
|
||||
)
|
||||
|
||||
assert scraper.rss_sources == custom_sources
|
||||
assert scraper.keywords == custom_keywords
|
||||
246
backend/tests/test_scraping_worker.py
Normal file
246
backend/tests/test_scraping_worker.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
Tests for scraping worker.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.workers.scraping_worker import (
|
||||
ScrapingWorker,
|
||||
create_scraping_worker
|
||||
)
|
||||
|
||||
|
||||
class TestScrapingWorker:
|
||||
"""Tests for ScrapingWorker class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test scraping worker initialization."""
|
||||
worker = ScrapingWorker(
|
||||
twitter_bearer_token="test_token",
|
||||
reddit_client_id="test_id",
|
||||
reddit_client_secret="test_secret"
|
||||
)
|
||||
|
||||
assert worker.twitter_bearer_token == "test_token"
|
||||
assert worker.reddit_client_id == "test_id"
|
||||
assert worker.reddit_client_secret == "test_secret"
|
||||
assert worker.twitter_scraper is None
|
||||
assert worker.reddit_scraper is None
|
||||
|
||||
def test_execute_scraping_task_twitter(self):
|
||||
"""Test executing a Twitter scraping task."""
|
||||
# Create worker
|
||||
worker = ScrapingWorker(
|
||||
twitter_bearer_token="test_token",
|
||||
reddit_client_id="test_id",
|
||||
reddit_client_secret="test_secret"
|
||||
)
|
||||
|
||||
# Mock Twitter scraper
|
||||
mock_twitter_scraper = Mock()
|
||||
worker.twitter_scraper = mock_twitter_scraper
|
||||
mock_twitter_scraper.scrape_and_save.return_value = [Mock()] * 50
|
||||
|
||||
# Mock database session
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Execute task
|
||||
task = {
|
||||
'match_id': 123,
|
||||
'source': 'twitter',
|
||||
'keywords': ['#MatchName'],
|
||||
'priority': 'normal'
|
||||
}
|
||||
|
||||
result = worker.execute_scraping_task(task, mock_db)
|
||||
|
||||
# Verify scraping called
|
||||
mock_twitter_scraper.scrape_and_save.assert_called_once_with(
|
||||
match_id=123,
|
||||
keywords=['#MatchName'],
|
||||
db=mock_db,
|
||||
max_results=100
|
||||
)
|
||||
|
||||
# Verify result
|
||||
assert result['collected_count'] == 50
|
||||
assert result['status'] == 'success'
|
||||
assert result['metadata']['source'] == 'twitter'
|
||||
assert result['metadata']['match_id'] == 123
|
||||
|
||||
def test_execute_scraping_task_reddit(self):
|
||||
"""Test executing a Reddit scraping task."""
|
||||
# Create worker
|
||||
worker = ScrapingWorker(
|
||||
twitter_bearer_token="test_token",
|
||||
reddit_client_id="test_id",
|
||||
reddit_client_secret="test_secret"
|
||||
)
|
||||
|
||||
# Mock Reddit scraper
|
||||
mock_reddit_scraper = Mock()
|
||||
worker.reddit_scraper = mock_reddit_scraper
|
||||
mock_reddit_scraper.scrape_and_save.return_value = {
|
||||
'posts': [Mock()] * 20,
|
||||
'comments': [Mock()] * 30
|
||||
}
|
||||
|
||||
# Mock database session
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Execute task
|
||||
task = {
|
||||
'match_id': 456,
|
||||
'source': 'reddit',
|
||||
'keywords': ['Ligue1'],
|
||||
'priority': 'vip'
|
||||
}
|
||||
|
||||
result = worker.execute_scraping_task(task, mock_db)
|
||||
|
||||
# Verify scraping called
|
||||
mock_reddit_scraper.scrape_and_save.assert_called_once_with(
|
||||
match_id=456,
|
||||
db=mock_db,
|
||||
keywords=['Ligue1'],
|
||||
scrape_comments=True
|
||||
)
|
||||
|
||||
# Verify result
|
||||
assert result['collected_count'] == 50 # 20 posts + 30 comments
|
||||
assert result['status'] == 'success'
|
||||
assert result['metadata']['source'] == 'reddit'
|
||||
assert result['metadata']['match_id'] == 456
|
||||
assert result['metadata']['posts_count'] == 20
|
||||
assert result['metadata']['comments_count'] == 30
|
||||
|
||||
def test_execute_scraping_task_unknown_source(self):
|
||||
"""Test executing task with unknown source."""
|
||||
worker = ScrapingWorker(
|
||||
twitter_bearer_token="test_token",
|
||||
reddit_client_id="test_id",
|
||||
reddit_client_secret="test_secret"
|
||||
)
|
||||
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Execute task with unknown source
|
||||
task = {
|
||||
'match_id': 123,
|
||||
'source': 'unknown',
|
||||
'keywords': ['#MatchName']
|
||||
}
|
||||
|
||||
result = worker.execute_scraping_task(task, mock_db)
|
||||
|
||||
# Verify error result
|
||||
assert result['collected_count'] == 0
|
||||
assert result['status'] == 'error'
|
||||
assert 'error' in result
|
||||
assert 'Unknown source' in result['error']
|
||||
|
||||
def test_execute_scraping_task_twitter_error(self):
|
||||
"""Test handling Twitter scraping errors."""
|
||||
worker = ScrapingWorker(
|
||||
twitter_bearer_token="test_token",
|
||||
reddit_client_id="test_id",
|
||||
reddit_client_secret="test_secret"
|
||||
)
|
||||
|
||||
# Mock Twitter scraper with error
|
||||
mock_twitter_scraper = Mock()
|
||||
worker.twitter_scraper = mock_twitter_scraper
|
||||
mock_twitter_scraper.scrape_and_save.side_effect = Exception("API Error")
|
||||
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Execute task
|
||||
task = {
|
||||
'match_id': 123,
|
||||
'source': 'twitter',
|
||||
'keywords': ['#MatchName']
|
||||
}
|
||||
|
||||
result = worker.execute_scraping_task(task, mock_db)
|
||||
|
||||
# Verify error handling
|
||||
assert result['collected_count'] == 0
|
||||
assert result['status'] == 'error'
|
||||
assert 'error' in result
|
||||
|
||||
@patch('app.workers.scraping_worker.create_twitter_scraper')
|
||||
def test_get_twitter_scraper_lazy_initialization(self, mock_create_scraper):
|
||||
"""Test lazy initialization of Twitter scraper."""
|
||||
worker = ScrapingWorker(
|
||||
twitter_bearer_token="test_token",
|
||||
reddit_client_id="test_id",
|
||||
reddit_client_secret="test_secret"
|
||||
)
|
||||
|
||||
# First call should create scraper
|
||||
mock_scraper_instance = Mock()
|
||||
mock_create_scraper.return_value = mock_scraper_instance
|
||||
|
||||
scraper1 = worker._get_twitter_scraper()
|
||||
|
||||
# Verify creation
|
||||
mock_create_scraper.assert_called_once_with(
|
||||
bearer_token="test_token",
|
||||
vip_match_ids=[]
|
||||
)
|
||||
assert scraper1 == mock_scraper_instance
|
||||
|
||||
# Second call should return same instance
|
||||
scraper2 = worker._get_twitter_scraper()
|
||||
assert scraper2 == scraper1
|
||||
|
||||
# Verify not created again
|
||||
assert mock_create_scraper.call_count == 1
|
||||
|
||||
@patch('app.workers.scraping_worker.create_reddit_scraper')
|
||||
def test_get_reddit_scraper_lazy_initialization(self, mock_create_scraper):
|
||||
"""Test lazy initialization of Reddit scraper."""
|
||||
worker = ScrapingWorker(
|
||||
twitter_bearer_token="test_token",
|
||||
reddit_client_id="test_id",
|
||||
reddit_client_secret="test_secret"
|
||||
)
|
||||
|
||||
# First call should create scraper
|
||||
mock_scraper_instance = Mock()
|
||||
mock_create_scraper.return_value = mock_scraper_instance
|
||||
|
||||
scraper1 = worker._get_reddit_scraper()
|
||||
|
||||
# Verify creation
|
||||
mock_create_scraper.assert_called_once_with(
|
||||
client_id="test_id",
|
||||
client_secret="test_secret"
|
||||
)
|
||||
assert scraper1 == mock_scraper_instance
|
||||
|
||||
# Second call should return same instance
|
||||
scraper2 = worker._get_reddit_scraper()
|
||||
assert scraper2 == scraper1
|
||||
|
||||
# Verify not created again
|
||||
assert mock_create_scraper.call_count == 1
|
||||
|
||||
|
||||
class TestCreateScrapingWorker:
|
||||
"""Tests for create_scraping_worker factory function."""
|
||||
|
||||
def test_create_scraping_worker(self):
|
||||
"""Test creating a scraping worker."""
|
||||
worker = create_scraping_worker(
|
||||
twitter_bearer_token="token123",
|
||||
reddit_client_id="id456",
|
||||
reddit_client_secret="secret789"
|
||||
)
|
||||
|
||||
assert isinstance(worker, ScrapingWorker)
|
||||
assert worker.twitter_bearer_token == "token123"
|
||||
assert worker.reddit_client_id == "id456"
|
||||
assert worker.reddit_client_secret == "secret789"
|
||||
201
backend/tests/test_sentiment_analyzer.py
Normal file
201
backend/tests/test_sentiment_analyzer.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Tests for Sentiment Analyzer
|
||||
|
||||
This module tests the VADER sentiment analyzer functionality.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from app.ml.sentiment_analyzer import (
|
||||
analyze_sentiment,
|
||||
analyze_sentiment_batch,
|
||||
calculate_aggregated_metrics,
|
||||
classify_sentiment,
|
||||
test_analyzer_performance
|
||||
)
|
||||
|
||||
|
||||
class TestClassifySentiment:
|
||||
"""Test sentiment classification based on compound score."""
|
||||
|
||||
def test_positive_classification(self):
|
||||
"""Test classification of positive sentiment."""
|
||||
assert classify_sentiment(0.5) == 'positive'
|
||||
assert classify_sentiment(0.05) == 'positive'
|
||||
assert classify_sentiment(1.0) == 'positive'
|
||||
|
||||
def test_negative_classification(self):
|
||||
"""Test classification of negative sentiment."""
|
||||
assert classify_sentiment(-0.5) == 'negative'
|
||||
assert classify_sentiment(-0.05) == 'negative'
|
||||
assert classify_sentiment(-1.0) == 'negative'
|
||||
|
||||
def test_neutral_classification(self):
|
||||
"""Test classification of neutral sentiment."""
|
||||
assert classify_sentiment(0.0) == 'neutral'
|
||||
assert classify_sentiment(0.04) == 'neutral'
|
||||
assert classify_sentiment(-0.04) == 'neutral'
|
||||
|
||||
|
||||
class TestAnalyzeSentiment:
|
||||
"""Test single text sentiment analysis."""
|
||||
|
||||
def test_positive_text(self):
|
||||
"""Test analysis of positive text."""
|
||||
text = "I love this game! It's absolutely amazing and wonderful!"
|
||||
result = analyze_sentiment(text)
|
||||
|
||||
assert result['sentiment'] == 'positive'
|
||||
assert result['compound'] > 0.05
|
||||
assert 0 <= result['positive'] <= 1
|
||||
assert 0 <= result['negative'] <= 1
|
||||
assert 0 <= result['neutral'] <= 1
|
||||
|
||||
def test_negative_text(self):
|
||||
"""Test analysis of negative text."""
|
||||
text = "This is terrible! I hate it and it's the worst ever."
|
||||
result = analyze_sentiment(text)
|
||||
|
||||
assert result['sentiment'] == 'negative'
|
||||
assert result['compound'] < -0.05
|
||||
assert 0 <= result['positive'] <= 1
|
||||
assert 0 <= result['negative'] <= 1
|
||||
assert 0 <= result['neutral'] <= 1
|
||||
|
||||
def test_neutral_text(self):
|
||||
"""Test analysis of neutral text."""
|
||||
text = "The game is okay. Nothing special but nothing bad."
|
||||
result = analyze_sentiment(text)
|
||||
|
||||
assert result['sentiment'] in ['positive', 'negative', 'neutral']
|
||||
assert 0 <= result['compound'] <= 1
|
||||
assert 0 <= result['positive'] <= 1
|
||||
assert 0 <= result['negative'] <= 1
|
||||
assert 0 <= result['neutral'] <= 1
|
||||
|
||||
def test_empty_text_raises_error(self):
|
||||
"""Test that empty text raises ValueError."""
|
||||
with pytest.raises(ValueError):
|
||||
analyze_sentiment("")
|
||||
|
||||
def test_invalid_input_raises_error(self):
|
||||
"""Test that invalid input raises ValueError."""
|
||||
with pytest.raises(ValueError):
|
||||
analyze_sentiment(None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
analyze_sentiment(123)
|
||||
|
||||
|
||||
class TestAnalyzeSentimentBatch:
|
||||
"""Test batch sentiment analysis."""
|
||||
|
||||
def test_batch_analysis_multiple_texts(self):
|
||||
"""Test analysis of multiple texts."""
|
||||
texts = [
|
||||
"I love this!",
|
||||
"This is terrible!",
|
||||
"It's okay."
|
||||
]
|
||||
results = analyze_sentiment_batch(texts)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all('compound' in r for r in results)
|
||||
assert all('positive' in r for r in results)
|
||||
assert all('negative' in r for r in results)
|
||||
assert all('neutral' in r for r in results)
|
||||
assert all('sentiment' in r for r in results)
|
||||
|
||||
def test_batch_analysis_empty_list(self):
|
||||
"""Test batch analysis with empty list."""
|
||||
results = analyze_sentiment_batch([])
|
||||
assert results == []
|
||||
|
||||
def test_batch_analysis_with_errors(self):
|
||||
"""Test batch analysis handles errors gracefully."""
|
||||
texts = [
|
||||
"This is good!",
|
||||
"", # Invalid - should be handled gracefully
|
||||
"This is great!"
|
||||
]
|
||||
results = analyze_sentiment_batch(texts)
|
||||
|
||||
assert len(results) == 3
|
||||
# The invalid text should return neutral sentiment
|
||||
assert results[1]['sentiment'] == 'neutral'
|
||||
assert results[1]['compound'] == 0.0
|
||||
|
||||
|
||||
class TestCalculateAggregatedMetrics:
|
||||
"""Test calculation of aggregated sentiment metrics."""
|
||||
|
||||
def test_aggregated_metrics_empty_list(self):
|
||||
"""Test aggregated metrics with empty list."""
|
||||
metrics = calculate_aggregated_metrics([])
|
||||
|
||||
assert metrics['total_count'] == 0
|
||||
assert metrics['positive_count'] == 0
|
||||
assert metrics['negative_count'] == 0
|
||||
assert metrics['neutral_count'] == 0
|
||||
assert metrics['positive_ratio'] == 0.0
|
||||
assert metrics['negative_ratio'] == 0.0
|
||||
assert metrics['neutral_ratio'] == 0.0
|
||||
assert metrics['average_compound'] == 0.0
|
||||
|
||||
def test_aggregated_metrics_all_positive(self):
|
||||
"""Test aggregated metrics with all positive sentiments."""
|
||||
sentiments = [
|
||||
{'compound': 0.5, 'sentiment': 'positive'},
|
||||
{'compound': 0.7, 'sentiment': 'positive'},
|
||||
{'compound': 0.9, 'sentiment': 'positive'}
|
||||
]
|
||||
metrics = calculate_aggregated_metrics(sentiments)
|
||||
|
||||
assert metrics['total_count'] == 3
|
||||
assert metrics['positive_count'] == 3
|
||||
assert metrics['negative_count'] == 0
|
||||
assert metrics['neutral_count'] == 0
|
||||
assert metrics['positive_ratio'] == 1.0
|
||||
assert metrics['average_compound'] == pytest.approx(0.7, rel=0.01)
|
||||
|
||||
def test_aggregated_metrics_mixed(self):
|
||||
"""Test aggregated metrics with mixed sentiments."""
|
||||
sentiments = [
|
||||
{'compound': 0.5, 'sentiment': 'positive'},
|
||||
{'compound': -0.5, 'sentiment': 'negative'},
|
||||
{'compound': 0.0, 'sentiment': 'neutral'},
|
||||
{'compound': 0.7, 'sentiment': 'positive'}
|
||||
]
|
||||
metrics = calculate_aggregated_metrics(sentiments)
|
||||
|
||||
assert metrics['total_count'] == 4
|
||||
assert metrics['positive_count'] == 2
|
||||
assert metrics['negative_count'] == 1
|
||||
assert metrics['neutral_count'] == 1
|
||||
assert metrics['positive_ratio'] == 0.5
|
||||
assert metrics['negative_ratio'] == 0.25
|
||||
assert metrics['neutral_ratio'] == 0.25
|
||||
assert metrics['average_compound'] == pytest.approx(0.175, rel=0.01)
|
||||
|
||||
|
||||
class TestAnalyzerPerformance:
|
||||
"""Test performance of the sentiment analyzer."""
|
||||
|
||||
def test_performance_1000_tweets(self):
|
||||
"""Test that 1000 tweets can be analyzed in less than 1 second."""
|
||||
time_taken = test_analyzer_performance(1000)
|
||||
|
||||
assert time_taken < 1.0, f"Analysis took {time_taken:.4f}s, expected < 1.0s"
|
||||
|
||||
def test_performance_5000_tweets(self):
|
||||
"""Test performance with 5000 tweets."""
|
||||
time_taken = test_analyzer_performance(5000)
|
||||
|
||||
# Should still be fast, but we allow some scaling
|
||||
assert time_taken < 5.0, f"Analysis took {time_taken:.4f}s, expected < 5.0s"
|
||||
|
||||
def test_performance_10000_tweets(self):
|
||||
"""Test performance with 10000 tweets."""
|
||||
time_taken = test_analyzer_performance(10000)
|
||||
|
||||
# Should still be very fast
|
||||
assert time_taken < 10.0, f"Analysis took {time_taken:.4f}s, expected < 10.0s"
|
||||
304
backend/tests/test_sentiment_service.py
Normal file
304
backend/tests/test_sentiment_service.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Tests for Sentiment Service
|
||||
|
||||
This module tests sentiment service functionality including database operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
|
||||
from app.services.sentiment_service import (
|
||||
process_tweet_sentiment,
|
||||
process_tweet_batch,
|
||||
process_reddit_post_sentiment,
|
||||
process_reddit_post_batch,
|
||||
get_sentiment_by_entity,
|
||||
get_sentiments_by_match,
|
||||
calculate_match_sentiment_metrics,
|
||||
get_global_sentiment_metrics
|
||||
)
|
||||
from app.models.sentiment_score import SentimentScore
|
||||
from app.models.tweet import Tweet
|
||||
from app.models.reddit_post import RedditPost
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tweet(db: Session):
|
||||
"""Create a sample tweet for testing."""
|
||||
tweet = Tweet(
|
||||
tweet_id="test_tweet_1",
|
||||
text="This is a test tweet about a great game!",
|
||||
created_at=datetime.utcnow(),
|
||||
retweet_count=5,
|
||||
like_count=10,
|
||||
match_id=1,
|
||||
source="twitter"
|
||||
)
|
||||
db.add(tweet)
|
||||
db.commit()
|
||||
db.refresh(tweet)
|
||||
return tweet
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tweets(db: Session):
|
||||
"""Create sample tweets for batch testing."""
|
||||
tweets = []
|
||||
texts = [
|
||||
"I love this team! Best performance ever!",
|
||||
"Terrible game today. Worst performance.",
|
||||
"It was an okay match. Nothing special.",
|
||||
"Amazing comeback! What a victory!",
|
||||
"Disappointed with the result."
|
||||
]
|
||||
for i, text in enumerate(texts):
|
||||
tweet = Tweet(
|
||||
tweet_id=f"test_tweet_{i}",
|
||||
text=text,
|
||||
created_at=datetime.utcnow(),
|
||||
retweet_count=i * 2,
|
||||
like_count=i * 3,
|
||||
match_id=1,
|
||||
source="twitter"
|
||||
)
|
||||
db.add(tweet)
|
||||
tweets.append(tweet)
|
||||
db.commit()
|
||||
for tweet in tweets:
|
||||
db.refresh(tweet)
|
||||
return tweets
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_reddit_post(db: Session):
|
||||
"""Create a sample Reddit post for testing."""
|
||||
post = RedditPost(
|
||||
post_id="test_post_1",
|
||||
title="This is a test post about a great game!",
|
||||
text="The team played amazingly well today!",
|
||||
upvotes=15,
|
||||
created_at=datetime.utcnow(),
|
||||
match_id=1,
|
||||
subreddit="test_subreddit",
|
||||
source="reddit"
|
||||
)
|
||||
db.add(post)
|
||||
db.commit()
|
||||
db.refresh(post)
|
||||
return post
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_reddit_posts(db: Session):
|
||||
"""Create sample Reddit posts for batch testing."""
|
||||
posts = []
|
||||
titles = [
|
||||
"Great game today!",
|
||||
"Terrible performance",
|
||||
"Okay match",
|
||||
"Amazing victory",
|
||||
"Disappointed result"
|
||||
]
|
||||
texts = [
|
||||
"The team was amazing!",
|
||||
"Worst game ever",
|
||||
"Nothing special",
|
||||
"What a comeback!",
|
||||
"Not good enough"
|
||||
]
|
||||
for i, (title, text) in enumerate(zip(titles, texts)):
|
||||
post = RedditPost(
|
||||
post_id=f"test_post_{i}",
|
||||
title=title,
|
||||
text=text,
|
||||
upvotes=i * 5,
|
||||
created_at=datetime.utcnow(),
|
||||
match_id=1,
|
||||
subreddit="test_subreddit",
|
||||
source="reddit"
|
||||
)
|
||||
db.add(post)
|
||||
posts.append(post)
|
||||
db.commit()
|
||||
for post in posts:
|
||||
db.refresh(post)
|
||||
return posts
|
||||
|
||||
|
||||
class TestProcessTweetSentiment:
|
||||
"""Test processing single tweet sentiment."""
|
||||
|
||||
def test_process_tweet_sentiment_creates_record(self, db: Session, sample_tweet: Tweet):
|
||||
"""Test that processing tweet sentiment creates a database record."""
|
||||
sentiment = process_tweet_sentiment(db, sample_tweet.tweet_id, sample_tweet.text)
|
||||
|
||||
assert sentiment.id is not None
|
||||
assert sentiment.entity_id == sample_tweet.tweet_id
|
||||
assert sentiment.entity_type == 'tweet'
|
||||
assert sentiment.sentiment_type in ['positive', 'negative', 'neutral']
|
||||
assert -1 <= sentiment.score <= 1
|
||||
assert 0 <= sentiment.positive <= 1
|
||||
assert 0 <= sentiment.negative <= 1
|
||||
assert 0 <= sentiment.neutral <= 1
|
||||
|
||||
def test_process_tweet_sentiment_positive(self, db: Session, sample_tweet: Tweet):
|
||||
"""Test processing positive tweet sentiment."""
|
||||
positive_text = "I absolutely love this team! Best performance ever!"
|
||||
sentiment = process_tweet_sentiment(db, "test_pos", positive_text)
|
||||
|
||||
assert sentiment.sentiment_type == 'positive'
|
||||
assert sentiment.score > 0.05
|
||||
|
||||
def test_process_tweet_sentiment_negative(self, db: Session, sample_tweet: Tweet):
|
||||
"""Test processing negative tweet sentiment."""
|
||||
negative_text = "This is terrible! Worst performance ever!"
|
||||
sentiment = process_tweet_sentiment(db, "test_neg", negative_text)
|
||||
|
||||
assert sentiment.sentiment_type == 'negative'
|
||||
assert sentiment.score < -0.05
|
||||
|
||||
|
||||
class TestProcessTweetBatch:
|
||||
"""Test batch processing of tweet sentiments."""
|
||||
|
||||
def test_process_tweet_batch(self, db: Session, sample_tweets: list[Tweet]):
|
||||
"""Test processing multiple tweets in batch."""
|
||||
sentiments = process_tweet_batch(db, sample_tweets)
|
||||
|
||||
assert len(sentiments) == 5
|
||||
assert all(s.entity_type == 'tweet' for s in sentiments)
|
||||
assert all(s.id is not None for s in sentiments)
|
||||
|
||||
def test_process_tweet_batch_empty(self, db: Session):
|
||||
"""Test processing empty batch."""
|
||||
sentiments = process_tweet_batch(db, [])
|
||||
assert sentiments == []
|
||||
|
||||
def test_process_tweet_batch_sentiments_calculated(self, db: Session, sample_tweets: list[Tweet]):
|
||||
"""Test that sentiments are correctly calculated for batch."""
|
||||
sentiments = process_tweet_batch(db, sample_tweets)
|
||||
|
||||
# Check that at least one positive and one negative sentiment exists
|
||||
sentiment_types = [s.sentiment_type for s in sentiments]
|
||||
assert 'positive' in sentiment_types
|
||||
assert 'negative' in sentiment_types
|
||||
|
||||
|
||||
class TestProcessRedditPostSentiment:
|
||||
"""Test processing single Reddit post sentiment."""
|
||||
|
||||
def test_process_reddit_post_sentiment_creates_record(self, db: Session, sample_reddit_post: RedditPost):
|
||||
"""Test that processing Reddit post sentiment creates a database record."""
|
||||
sentiment = process_reddit_post_sentiment(
|
||||
db,
|
||||
sample_reddit_post.post_id,
|
||||
f"{sample_reddit_post.title} {sample_reddit_post.text}"
|
||||
)
|
||||
|
||||
assert sentiment.id is not None
|
||||
assert sentiment.entity_id == sample_reddit_post.post_id
|
||||
assert sentiment.entity_type == 'reddit_post'
|
||||
assert sentiment.sentiment_type in ['positive', 'negative', 'neutral']
|
||||
|
||||
|
||||
class TestProcessRedditPostBatch:
|
||||
"""Test batch processing of Reddit post sentiments."""
|
||||
|
||||
def test_process_reddit_post_batch(self, db: Session, sample_reddit_posts: list[RedditPost]):
|
||||
"""Test processing multiple Reddit posts in batch."""
|
||||
sentiments = process_reddit_post_batch(db, sample_reddit_posts)
|
||||
|
||||
assert len(sentiments) == 5
|
||||
assert all(s.entity_type == 'reddit_post' for s in sentiments)
|
||||
assert all(s.id is not None for s in sentiments)
|
||||
|
||||
def test_process_reddit_post_batch_empty(self, db: Session):
|
||||
"""Test processing empty batch."""
|
||||
sentiments = process_reddit_post_batch(db, [])
|
||||
assert sentiments == []
|
||||
|
||||
|
||||
class TestGetSentimentByEntity:
|
||||
"""Test retrieving sentiment by entity."""
|
||||
|
||||
def test_get_sentiment_by_entity_found(self, db: Session, sample_tweet: Tweet):
|
||||
"""Test retrieving existing sentiment by entity."""
|
||||
process_tweet_sentiment(db, sample_tweet.tweet_id, sample_tweet.text)
|
||||
|
||||
sentiment = get_sentiment_by_entity(
|
||||
db,
|
||||
sample_tweet.tweet_id,
|
||||
'tweet'
|
||||
)
|
||||
|
||||
assert sentiment is not None
|
||||
assert sentiment.entity_id == sample_tweet.tweet_id
|
||||
|
||||
def test_get_sentiment_by_entity_not_found(self, db: Session):
|
||||
"""Test retrieving non-existent sentiment."""
|
||||
sentiment = get_sentiment_by_entity(db, "nonexistent_id", "tweet")
|
||||
assert sentiment is None
|
||||
|
||||
|
||||
class TestGetSentimentsByMatch:
|
||||
"""Test retrieving sentiments by match."""
|
||||
|
||||
def test_get_sentiments_by_match(self, db: Session, sample_tweets: list[Tweet]):
|
||||
"""Test retrieving all sentiments for a match."""
|
||||
process_tweet_batch(db, sample_tweets)
|
||||
|
||||
sentiments = get_sentiments_by_match(db, 1)
|
||||
|
||||
assert len(sentiments) == 5
|
||||
assert all(s.entity_type == 'tweet' for s in sentiments)
|
||||
|
||||
def test_get_sentiments_by_match_empty(self, db: Session):
|
||||
"""Test retrieving sentiments for match with no data."""
|
||||
sentiments = get_sentiments_by_match(db, 999)
|
||||
assert sentiments == []
|
||||
|
||||
|
||||
class TestCalculateMatchSentimentMetrics:
|
||||
"""Test calculating sentiment metrics for a match."""
|
||||
|
||||
def test_calculate_match_sentiment_metrics(self, db: Session, sample_tweets: list[Tweet]):
|
||||
"""Test calculating metrics for a match."""
|
||||
process_tweet_batch(db, sample_tweets)
|
||||
|
||||
metrics = calculate_match_sentiment_metrics(db, 1)
|
||||
|
||||
assert metrics['match_id'] == 1
|
||||
assert metrics['total_count'] == 5
|
||||
assert metrics['positive_count'] + metrics['negative_count'] + metrics['neutral_count'] == 5
|
||||
assert metrics['positive_ratio'] + metrics['negative_ratio'] + metrics['neutral_ratio'] == 1.0
|
||||
assert -1 <= metrics['average_compound'] <= 1
|
||||
|
||||
def test_calculate_match_sentiment_metrics_empty(self, db: Session):
|
||||
"""Test calculating metrics for match with no data."""
|
||||
metrics = calculate_match_sentiment_metrics(db, 999)
|
||||
|
||||
assert metrics['match_id'] == 999
|
||||
assert metrics['total_count'] == 0
|
||||
assert metrics['average_compound'] == 0.0
|
||||
|
||||
|
||||
class TestGetGlobalSentimentMetrics:
|
||||
"""Test calculating global sentiment metrics."""
|
||||
|
||||
def test_get_global_sentiment_metrics(self, db: Session, sample_tweets: list[Tweet]):
|
||||
"""Test calculating global metrics."""
|
||||
process_tweet_batch(db, sample_tweets)
|
||||
|
||||
metrics = get_global_sentiment_metrics(db)
|
||||
|
||||
assert metrics['total_count'] == 5
|
||||
assert metrics['positive_count'] + metrics['negative_count'] + metrics['neutral_count'] == 5
|
||||
assert metrics['positive_ratio'] + metrics['negative_ratio'] + metrics['neutral_ratio'] == 1.0
|
||||
|
||||
def test_get_global_sentiment_metrics_empty(self, db: Session):
|
||||
"""Test calculating global metrics with no data."""
|
||||
metrics = get_global_sentiment_metrics(db)
|
||||
|
||||
assert metrics['total_count'] == 0
|
||||
assert metrics['average_compound'] == 0.0
|
||||
251
backend/tests/test_sentiment_worker.py
Normal file
251
backend/tests/test_sentiment_worker.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Tests for sentiment analysis worker.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.workers.sentiment_worker import (
|
||||
SentimentWorker,
|
||||
create_sentiment_worker
|
||||
)
|
||||
|
||||
|
||||
class TestSentimentWorker:
|
||||
"""Tests for SentimentWorker class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test sentiment worker initialization."""
|
||||
worker = SentimentWorker()
|
||||
|
||||
# No specific initialization required for sentiment worker
|
||||
assert worker is not None
|
||||
|
||||
@patch('app.workers.sentiment_worker.process_tweet_batch')
|
||||
@patch('app.workers.sentiment_worker.get_sentiment_by_entity')
|
||||
def test_execute_sentiment_analysis_task_twitter(
|
||||
self,
|
||||
mock_get_sentiment,
|
||||
mock_process_batch
|
||||
):
|
||||
"""Test executing a Twitter sentiment analysis task."""
|
||||
# Create worker
|
||||
worker = SentimentWorker()
|
||||
|
||||
# Mock database session
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Mock tweet query
|
||||
mock_tweets = [Mock()] * 100
|
||||
mock_db.query.return_value.filter.return_value.all.return_value = mock_tweets
|
||||
|
||||
# Mock existing sentiment (none exist)
|
||||
mock_get_sentiment.return_value = None
|
||||
|
||||
# Mock batch processing
|
||||
mock_process_batch.return_value = [Mock()] * 100
|
||||
|
||||
# Mock calculate metrics
|
||||
with patch.object(
|
||||
worker,
|
||||
'_calculate_sentiment_metrics',
|
||||
return_value={
|
||||
'total_count': 100,
|
||||
'positive_count': 60,
|
||||
'negative_count': 20,
|
||||
'neutral_count': 20,
|
||||
'average_compound': 0.35
|
||||
}
|
||||
):
|
||||
# Execute task
|
||||
task = {
|
||||
'match_id': 123,
|
||||
'source': 'twitter',
|
||||
'entity_ids': ['tweet1', 'tweet2']
|
||||
}
|
||||
|
||||
result = worker.execute_sentiment_analysis_task(task, mock_db)
|
||||
|
||||
# Verify result
|
||||
assert result['analyzed_count'] == 100
|
||||
assert result['status'] == 'success'
|
||||
assert result['metrics']['total_count'] == 100
|
||||
assert result['metrics']['positive_count'] == 60
|
||||
|
||||
@patch('app.workers.sentiment_worker.process_reddit_post_batch')
|
||||
@patch('app.workers.sentiment_worker.get_sentiment_by_entity')
|
||||
def test_execute_sentiment_analysis_task_reddit(
|
||||
self,
|
||||
mock_get_sentiment,
|
||||
mock_process_batch
|
||||
):
|
||||
"""Test executing a Reddit sentiment analysis task."""
|
||||
# Create worker
|
||||
worker = SentimentWorker()
|
||||
|
||||
# Mock database session
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Mock Reddit post query
|
||||
mock_posts = [Mock()] * 50
|
||||
mock_db.query.return_value.filter.return_value.all.return_value = mock_posts
|
||||
|
||||
# Mock existing sentiment (none exist)
|
||||
mock_get_sentiment.return_value = None
|
||||
|
||||
# Mock batch processing
|
||||
mock_process_batch.return_value = [Mock()] * 50
|
||||
|
||||
# Mock calculate metrics
|
||||
with patch.object(
|
||||
worker,
|
||||
'_calculate_sentiment_metrics',
|
||||
return_value={
|
||||
'total_count': 50,
|
||||
'positive_count': 25,
|
||||
'negative_count': 15,
|
||||
'neutral_count': 10,
|
||||
'average_compound': 0.2
|
||||
}
|
||||
):
|
||||
# Execute task
|
||||
task = {
|
||||
'match_id': 456,
|
||||
'source': 'reddit',
|
||||
'entity_ids': ['post1', 'post2']
|
||||
}
|
||||
|
||||
result = worker.execute_sentiment_analysis_task(task, mock_db)
|
||||
|
||||
# Verify result
|
||||
assert result['analyzed_count'] == 50
|
||||
assert result['status'] == 'success'
|
||||
assert result['metrics']['total_count'] == 50
|
||||
|
||||
@patch('app.workers.sentiment_worker.get_sentiment_by_entity')
|
||||
def test_execute_sentiment_analysis_task_all_analyzed(
|
||||
self,
|
||||
mock_get_sentiment
|
||||
):
|
||||
"""Test executing task when all items already analyzed."""
|
||||
worker = SentimentWorker()
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Mock tweets
|
||||
mock_tweets = [Mock()] * 100
|
||||
mock_db.query.return_value.filter.return_value.all.return_value = mock_tweets
|
||||
|
||||
# Mock existing sentiment (all exist)
|
||||
mock_get_sentiment.return_value = Mock()
|
||||
|
||||
# Mock calculate metrics from existing
|
||||
with patch.object(
|
||||
worker,
|
||||
'_calculate_metrics_from_existing',
|
||||
return_value={
|
||||
'total_count': 100,
|
||||
'average_compound': 0.4
|
||||
}
|
||||
):
|
||||
# Execute task
|
||||
task = {
|
||||
'match_id': 123,
|
||||
'source': 'twitter',
|
||||
'entity_ids': ['tweet1', 'tweet2']
|
||||
}
|
||||
|
||||
result = worker.execute_sentiment_analysis_task(task, mock_db)
|
||||
|
||||
# Verify result (0 new items analyzed)
|
||||
assert result['analyzed_count'] == 0
|
||||
assert result['status'] == 'success'
|
||||
|
||||
def test_execute_sentiment_analysis_task_unknown_source(self):
|
||||
"""Test executing task with unknown source."""
|
||||
worker = SentimentWorker()
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Execute task with unknown source
|
||||
task = {
|
||||
'match_id': 123,
|
||||
'source': 'unknown',
|
||||
'entity_ids': ['item1']
|
||||
}
|
||||
|
||||
result = worker.execute_sentiment_analysis_task(task, mock_db)
|
||||
|
||||
# Verify error result
|
||||
assert result['analyzed_count'] == 0
|
||||
assert result['status'] == 'error'
|
||||
assert 'error' in result
|
||||
assert 'Unknown source' in result['error']
|
||||
|
||||
@patch('app.workers.sentiment_worker.process_tweet_batch')
|
||||
@patch('app.workers.sentiment_worker.get_sentiment_by_entity')
|
||||
def test_execute_sentiment_analysis_task_error_handling(
|
||||
self,
|
||||
mock_get_sentiment,
|
||||
mock_process_batch
|
||||
):
|
||||
"""Test error handling in sentiment analysis."""
|
||||
worker = SentimentWorker()
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Mock tweets
|
||||
mock_tweets = [Mock()] * 10
|
||||
mock_db.query.return_value.filter.return_value.all.return_value = mock_tweets
|
||||
|
||||
# Mock no existing sentiment
|
||||
mock_get_sentiment.return_value = None
|
||||
|
||||
# Mock processing error
|
||||
mock_process_batch.side_effect = Exception("Processing error")
|
||||
|
||||
# Execute task
|
||||
task = {
|
||||
'match_id': 123,
|
||||
'source': 'twitter',
|
||||
'entity_ids': ['tweet1', 'tweet2']
|
||||
}
|
||||
|
||||
result = worker.execute_sentiment_analysis_task(task, mock_db)
|
||||
|
||||
# Verify error handling
|
||||
assert result['analyzed_count'] == 0
|
||||
assert result['status'] == 'error'
|
||||
assert 'error' in result
|
||||
|
||||
@patch('app.workers.sentiment_worker.get_sentiment_by_entity')
|
||||
def test_execute_twitter_sentiment_analysis_no_tweets(
|
||||
self,
|
||||
mock_get_sentiment
|
||||
):
|
||||
"""Test Twitter sentiment analysis when no tweets found."""
|
||||
worker = SentimentWorker()
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
# Mock no tweets
|
||||
mock_db.query.return_value.filter.return_value.all.return_value = []
|
||||
|
||||
# Execute task
|
||||
result = worker._execute_twitter_sentiment_analysis(
|
||||
match_id=123,
|
||||
entity_ids=['tweet1', 'tweet2'],
|
||||
db=mock_db
|
||||
)
|
||||
|
||||
# Verify result
|
||||
assert result['analyzed_count'] == 0
|
||||
assert result['status'] == 'success'
|
||||
assert result['metrics']['total_count'] == 0
|
||||
|
||||
|
||||
class TestCreateSentimentWorker:
|
||||
"""Tests for create_sentiment_worker factory function."""
|
||||
|
||||
def test_create_sentiment_worker(self):
|
||||
"""Test creating a sentiment worker."""
|
||||
worker = create_sentiment_worker()
|
||||
|
||||
assert isinstance(worker, SentimentWorker)
|
||||
140
backend/tests/test_tweet_model.py
Normal file
140
backend/tests/test_tweet_model.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Unit tests for Tweet model.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tweet import Tweet
|
||||
from app.database import Base, engine, SessionLocal
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
"""Create a fresh database session for each test."""
|
||||
# Create tables
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# Create session
|
||||
session = SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
session.rollback()
|
||||
finally:
|
||||
session.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
class TestTweetModel:
|
||||
"""Test Tweet SQLAlchemy model."""
|
||||
|
||||
def test_tweet_creation(self, db_session: Session):
|
||||
"""Test creating a tweet in the database."""
|
||||
tweet = Tweet(
|
||||
tweet_id="123456789",
|
||||
text="Test tweet content",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
retweet_count=10,
|
||||
like_count=20,
|
||||
match_id=1,
|
||||
source="twitter"
|
||||
)
|
||||
|
||||
db_session.add(tweet)
|
||||
db_session.commit()
|
||||
db_session.refresh(tweet)
|
||||
|
||||
assert tweet.id is not None
|
||||
assert tweet.tweet_id == "123456789"
|
||||
assert tweet.text == "Test tweet content"
|
||||
assert tweet.retweet_count == 10
|
||||
assert tweet.like_count == 20
|
||||
assert tweet.match_id == 1
|
||||
assert tweet.source == "twitter"
|
||||
|
||||
def test_tweet_defaults(self, db_session: Session):
|
||||
"""Test tweet default values."""
|
||||
tweet = Tweet(
|
||||
tweet_id="987654321",
|
||||
text="Another test tweet",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db_session.add(tweet)
|
||||
db_session.commit()
|
||||
db_session.refresh(tweet)
|
||||
|
||||
assert tweet.retweet_count == 0
|
||||
assert tweet.like_count == 0
|
||||
assert tweet.source == "twitter"
|
||||
assert tweet.match_id is None
|
||||
|
||||
def test_tweet_unique_constraint(self, db_session: Session):
|
||||
"""Test that tweet_id must be unique."""
|
||||
tweet1 = Tweet(
|
||||
tweet_id="111111111",
|
||||
text="First tweet",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
tweet2 = Tweet(
|
||||
tweet_id="111111111", # Same tweet_id
|
||||
text="Second tweet",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db_session.add(tweet1)
|
||||
db_session.commit()
|
||||
|
||||
db_session.add(tweet2)
|
||||
|
||||
with pytest.raises(Exception): # IntegrityError expected
|
||||
db_session.commit()
|
||||
|
||||
def test_tweet_to_dict(self, db_session: Session):
|
||||
"""Test converting tweet to dictionary."""
|
||||
tweet = Tweet(
|
||||
tweet_id="222222222",
|
||||
text="Test tweet for dict",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
retweet_count=5,
|
||||
like_count=10,
|
||||
match_id=2,
|
||||
source="reddit"
|
||||
)
|
||||
|
||||
db_session.add(tweet)
|
||||
db_session.commit()
|
||||
db_session.refresh(tweet)
|
||||
|
||||
tweet_dict = tweet.to_dict()
|
||||
|
||||
assert tweet_dict['tweet_id'] == "222222222"
|
||||
assert tweet_dict['text'] == "Test tweet for dict"
|
||||
assert tweet_dict['retweet_count'] == 5
|
||||
assert tweet_dict['like_count'] == 10
|
||||
assert tweet_dict['match_id'] == 2
|
||||
assert tweet_dict['source'] == "reddit"
|
||||
assert 'id' in tweet_dict
|
||||
assert 'created_at' in tweet_dict
|
||||
|
||||
def test_tweet_repr(self, db_session: Session):
|
||||
"""Test tweet __repr__ method."""
|
||||
tweet = Tweet(
|
||||
tweet_id="333333333",
|
||||
text="Test tweet repr",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
match_id=3
|
||||
)
|
||||
|
||||
db_session.add(tweet)
|
||||
db_session.commit()
|
||||
db_session.refresh(tweet)
|
||||
|
||||
repr_str = repr(tweet)
|
||||
|
||||
assert "Tweet" in repr_str
|
||||
assert "id=" in repr_str
|
||||
assert "tweet_id=333333333" in repr_str
|
||||
assert "match_id=3" in repr_str
|
||||
186
backend/tests/test_twitter_scraper.py
Normal file
186
backend/tests/test_twitter_scraper.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
Unit tests for Twitter scraper.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from app.scrapers.twitter_scraper import (
|
||||
TwitterScraper,
|
||||
RateLimitInfo,
|
||||
TweetData,
|
||||
create_twitter_scraper
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tweepy_client():
|
||||
"""Mock tweepy Client."""
|
||||
with patch('app.scrapers.twitter_scraper.tweepy.Client') as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_bearer_token():
|
||||
"""Test bearer token."""
|
||||
return "test_bearer_token_12345"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_scraper(test_bearer_token):
|
||||
"""Create test scraper instance."""
|
||||
with patch('app.scrapers.twitter_scraper.tweepy.Client'):
|
||||
scraper = TwitterScraper(
|
||||
bearer_token=test_bearer_token,
|
||||
max_tweets_per_hour=100,
|
||||
rate_limit_alert_threshold=0.9,
|
||||
vip_match_ids=[1, 2, 3]
|
||||
)
|
||||
return scraper
|
||||
|
||||
|
||||
class TestRateLimitInfo:
|
||||
"""Test RateLimitInfo dataclass."""
|
||||
|
||||
def test_usage_percentage(self):
|
||||
"""Test usage percentage calculation."""
|
||||
info = RateLimitInfo(remaining=100, limit=1000, reset_time=None)
|
||||
assert info.usage_percentage == 0.9
|
||||
|
||||
info = RateLimitInfo(remaining=0, limit=1000, reset_time=None)
|
||||
assert info.usage_percentage == 1.0
|
||||
|
||||
info = RateLimitInfo(remaining=1000, limit=1000, reset_time=None)
|
||||
assert info.usage_percentage == 0.0
|
||||
|
||||
|
||||
class TestTweetData:
|
||||
"""Test TweetData dataclass."""
|
||||
|
||||
def test_tweet_data_creation(self):
|
||||
"""Test creating TweetData instance."""
|
||||
tweet = TweetData(
|
||||
tweet_id="123456789",
|
||||
text="Test tweet content",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
retweet_count=10,
|
||||
like_count=20,
|
||||
match_id=1
|
||||
)
|
||||
|
||||
assert tweet.tweet_id == "123456789"
|
||||
assert tweet.source == "twitter"
|
||||
assert tweet.match_id == 1
|
||||
assert tweet.retweet_count == 10
|
||||
assert tweet.like_count == 20
|
||||
|
||||
|
||||
class TestTwitterScraper:
|
||||
"""Test TwitterScraper class."""
|
||||
|
||||
def test_scraper_initialization(self, test_bearer_token):
|
||||
"""Test scraper initialization."""
|
||||
with patch('app.scrapers.twitter_scraper.tweepy.Client'):
|
||||
scraper = TwitterScraper(
|
||||
bearer_token=test_bearer_token,
|
||||
max_tweets_per_hour=1000,
|
||||
rate_limit_alert_threshold=0.9,
|
||||
vip_match_ids=[1, 2, 3]
|
||||
)
|
||||
|
||||
assert scraper.bearer_token == test_bearer_token
|
||||
assert scraper.max_tweets_per_hour == 1000
|
||||
assert scraper.rate_limit_alert_threshold == 0.9
|
||||
assert scraper.vip_match_ids == [1, 2, 3]
|
||||
assert scraper.vip_mode_only is False
|
||||
assert scraper.api_calls_made == 0
|
||||
|
||||
def test_check_rate_limit_normal(self, test_scraper):
|
||||
"""Test rate limit check under normal conditions."""
|
||||
test_scraper.api_calls_made = 800 # 80% usage
|
||||
assert test_scraper._check_rate_limit() is True
|
||||
|
||||
def test_check_rate_limit_alert(self, test_scraper, caplog):
|
||||
"""Test rate limit alert at threshold."""
|
||||
test_scraper.api_calls_made = 900 # 90% usage
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
result = test_scraper._check_rate_limit()
|
||||
|
||||
assert result is True
|
||||
assert "Rate limit approaching" in caplog.text
|
||||
|
||||
def test_check_rate_limit_exceeded(self, test_scraper, caplog):
|
||||
"""Test rate limit exceeded."""
|
||||
test_scraper.api_calls_made = 1000 # 100% usage
|
||||
|
||||
with caplog.at_level("ERROR"):
|
||||
result = test_scraper._check_rate_limit()
|
||||
|
||||
assert result is False
|
||||
assert "Rate limit reached" in caplog.text
|
||||
|
||||
def test_enable_vip_mode_only(self, test_scraper, caplog):
|
||||
"""Test enabling VIP mode."""
|
||||
with caplog.at_level("WARNING"):
|
||||
test_scraper._enable_vip_mode_only()
|
||||
|
||||
assert test_scraper.vip_mode_only is True
|
||||
assert "ENTERING DEGRADED MODE" in caplog.text
|
||||
assert "VIP match IDs:" in caplog.text
|
||||
|
||||
def test_scrape_non_vip_in_vip_mode(self, test_scraper):
|
||||
"""Test scraping non-VIP match when VIP mode is active."""
|
||||
test_scraper.vip_mode_only = True
|
||||
|
||||
with pytest.raises(ValueError, match="Match 4 is not VIP"):
|
||||
test_scraper.scrape_twitter_match(
|
||||
match_id=4,
|
||||
keywords=["test"],
|
||||
max_results=10
|
||||
)
|
||||
|
||||
def test_scrape_vip_in_vip_mode(self, test_scraper, mock_tweepy_client):
|
||||
"""Test scraping VIP match when VIP mode is active."""
|
||||
test_scraper.vip_mode_only = True
|
||||
|
||||
# Mock API response
|
||||
mock_response = Mock()
|
||||
mock_response.data = [
|
||||
Mock(
|
||||
id="123456789",
|
||||
text="Test tweet",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
public_metrics={'retweet_count': 10, 'like_count': 20}
|
||||
)
|
||||
]
|
||||
mock_tweepy_client.return_value.search_recent_tweets.return_value = mock_response
|
||||
|
||||
result = test_scraper.scrape_twitter_match(
|
||||
match_id=1,
|
||||
keywords=["test"],
|
||||
max_results=10
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].tweet_id == "123456789"
|
||||
assert test_scraper.api_calls_made == 1
|
||||
|
||||
|
||||
class TestCreateTwitterScraper:
|
||||
"""Test create_twitter_scraper factory function."""
|
||||
|
||||
def test_factory_function(self, test_bearer_token):
|
||||
"""Test factory function creates scraper."""
|
||||
with patch('app.scrapers.twitter_scraper.TwitterScraper') as MockScraper:
|
||||
mock_instance = Mock()
|
||||
MockScraper.return_value = mock_instance
|
||||
|
||||
result = create_twitter_scraper(
|
||||
bearer_token=test_bearer_token,
|
||||
vip_match_ids=[1, 2, 3]
|
||||
)
|
||||
|
||||
MockScraper.assert_called_once()
|
||||
assert result == mock_instance
|
||||
Reference in New Issue
Block a user