382 lines
13 KiB
Python
382 lines
13 KiB
Python
"""
|
||
Unit tests for Reddit scraper.
|
||
"""
|
||
|
||
import pytest
|
||
from datetime import datetime, timezone
|
||
from unittest.mock import Mock, patch, MagicMock
|
||
import praw.exceptions
|
||
|
||
from app.scrapers.reddit_scraper import (
|
||
RedditScraper,
|
||
RedditPostData,
|
||
RedditCommentData,
|
||
create_reddit_scraper
|
||
)
|
||
|
||
|
||
@pytest.fixture
def mock_praw():
    """Yield the patched ``praw.Reddit`` class so tests never hit the network."""
    with patch('app.scrapers.reddit_scraper.praw.Reddit') as reddit_cls:
        yield reddit_cls
|
||
|
||
|
||
@pytest.fixture
def test_client_id():
    """Dummy Reddit client ID shared by the scraper tests."""
    return "test_client_id"
|
||
|
||
|
||
@pytest.fixture
def test_client_secret():
    """Dummy Reddit client secret shared by the scraper tests."""
    return "test_client_secret"
|
||
|
||
|
||
@pytest.fixture
def test_scraper(test_client_id, test_client_secret):
    """Build a RedditScraper with ``praw.Reddit`` patched out of its __init__."""
    with patch('app.scrapers.reddit_scraper.praw.Reddit'):
        return RedditScraper(
            client_id=test_client_id,
            client_secret=test_client_secret,
            subreddits=["soccer", "football"],
            max_posts_per_subreddit=100,
            max_comments_per_post=50,
        )
|
||
|
||
|
||
class TestRedditPostData:
    """Tests for the RedditPostData dataclass."""

    def test_reddit_post_data_creation(self):
        """All constructor fields are stored and ``source`` defaults to 'reddit'."""
        created = datetime.now(timezone.utc)
        post = RedditPostData(
            post_id="abc123",
            title="Test post title",
            text="Test post content",
            upvotes=100,
            created_at=created,
            match_id=1,
            subreddit="soccer",
        )

        assert post.post_id == "abc123"
        assert post.title == "Test post title"
        assert post.text == "Test post content"
        assert post.upvotes == 100
        assert post.match_id == 1
        assert post.subreddit == "soccer"
        # Not passed to the constructor — must come from the dataclass default.
        assert post.source == "reddit"
|
||
|
||
|
||
class TestRedditCommentData:
    """Tests for the RedditCommentData dataclass."""

    def test_reddit_comment_data_creation(self):
        """All constructor fields are stored and ``source`` defaults to 'reddit'."""
        created = datetime.now(timezone.utc)
        comment = RedditCommentData(
            comment_id="def456",
            post_id="abc123",
            text="Test comment content",
            upvotes=50,
            created_at=created,
        )

        assert comment.comment_id == "def456"
        assert comment.post_id == "abc123"
        assert comment.text == "Test comment content"
        assert comment.upvotes == 50
        # Not passed to the constructor — must come from the dataclass default.
        assert comment.source == "reddit"
|
||
|
||
|
||
class TestRedditScraper:
    """Behavioural tests for the RedditScraper class."""

    def test_scraper_initialization(self, test_client_id, test_client_secret):
        """The constructor stores credentials and scraping limits verbatim."""
        with patch('app.scrapers.reddit_scraper.praw.Reddit'):
            scraper = RedditScraper(
                client_id=test_client_id,
                client_secret=test_client_secret,
                subreddits=["soccer", "football"],
                max_posts_per_subreddit=100,
                max_comments_per_post=50,
            )

        assert scraper.client_id == test_client_id
        assert scraper.client_secret == test_client_secret
        assert scraper.subreddits == ["soccer", "football"]
        assert scraper.max_posts_per_subreddit == 100
        assert scraper.max_comments_per_post == 50

    def test_verify_authentication_success(self, mock_praw, caplog):
        """A reachable Reddit API logs the success banner during construction."""
        fake_reddit = Mock()
        # NOTE(review): Mock(name=...) sets the mock's repr name, not a
        # ``.name`` attribute — harmless here since only the log is asserted.
        fake_reddit.user.me.return_value = Mock(name="test_user")
        mock_praw.return_value = fake_reddit

        with caplog.at_level("INFO"):
            RedditScraper(
                client_id="test_id",
                client_secret="test_secret",
                subreddits=["soccer"],
            )

        assert "✅ Reddit API authenticated successfully" in caplog.text

    def test_verify_authentication_failure(self, mock_praw, caplog):
        """A failing ``user.me()`` call surfaces as an authentication error."""
        fake_reddit = Mock()
        fake_reddit.user.me.side_effect = Exception("Auth failed")
        mock_praw.return_value = fake_reddit

        with pytest.raises(Exception, match="Reddit API authentication failed"):
            RedditScraper(
                client_id="invalid_id",
                client_secret="invalid_secret",
                subreddits=["soccer"],
            )

    def test_scrape_posts_empty(self, test_scraper, mock_praw, caplog):
        """An empty subreddit listing yields [] and logs an info message."""
        fake_subreddit = Mock()
        fake_subreddit.new.return_value = []
        fake_reddit = Mock()
        fake_reddit.subreddit.return_value = fake_subreddit
        mock_praw.return_value = fake_reddit

        with caplog.at_level("INFO"):
            result = test_scraper.scrape_posts(
                subreddit="soccer",
                match_id=1,
                keywords=["test"],
            )

        assert result == []
        assert "ℹ️ No posts found" in caplog.text

    def test_scrape_posts_success(self, test_scraper, mock_praw):
        """A matching submission is converted into one RedditPostData record."""
        submission = Mock()
        submission.id = "abc123"
        submission.title = "Test match discussion"
        submission.selftext = "Great match today!"
        submission.score = 100
        submission.created_utc = 1700000000.0

        fake_subreddit = Mock()
        fake_subreddit.new.return_value = [submission]
        fake_reddit = Mock()
        fake_reddit.subreddit.return_value = fake_subreddit
        mock_praw.return_value = fake_reddit

        result = test_scraper.scrape_posts(
            subreddit="soccer",
            match_id=1,
            keywords=["test"],
        )

        assert len(result) == 1
        assert result[0].post_id == "abc123"
        assert result[0].title == "Test match discussion"
        assert result[0].upvotes == 100

    def test_scrape_comments_success(self, test_scraper, mock_praw):
        """Comments on a post are flattened into RedditCommentData records."""
        fake_comment = Mock()
        fake_comment.id = "def456"
        fake_comment.body = "Great goal!"
        fake_comment.score = 50
        fake_comment.created_utc = 1700000000.0

        parent_post = Mock()
        parent_post.comments.list.return_value = [fake_comment]

        result = test_scraper.scrape_comments(post_id="abc123", post=parent_post)

        assert len(result) == 1
        assert result[0].comment_id == "def456"
        assert result[0].text == "Great goal!"
        assert result[0].upvotes == 50

    def test_save_posts_to_db(self, test_scraper, mock_praw):
        """New posts are added to the session exactly once and committed."""
        from app.models.reddit_post import RedditPost  # noqa: F401  (import side effects)

        posts_data = [
            RedditPostData(
                post_id="abc123",
                title="Test post",
                text="Test content",
                upvotes=100,
                created_at=datetime.now(timezone.utc),
                match_id=1,
                subreddit="soccer",
            )
        ]

        # No existing row → the scraper should insert.
        session = Mock()
        session.query.return_value.filter.return_value.first.return_value = None

        saved_post = Mock()
        saved_post.id = 1

        with patch('app.scrapers.reddit_scraper.RedditPost', return_value=saved_post):
            test_scraper.save_posts_to_db(posts_data, session)

        session.add.assert_called_once()
        session.commit.assert_called_once()

    def test_save_comments_to_db(self, test_scraper):
        """New comments are added to the session exactly once and committed."""
        comments_data = [
            RedditCommentData(
                comment_id="def456",
                post_id="abc123",
                text="Test comment",
                upvotes=50,
                created_at=datetime.now(timezone.utc),
            )
        ]

        # No existing row → the scraper should insert.
        session = Mock()
        session.query.return_value.filter.return_value.first.return_value = None

        saved_comment = Mock()
        saved_comment.id = 1

        with patch('app.scrapers.reddit_scraper.RedditComment', return_value=saved_comment):
            test_scraper.save_comments_to_db(comments_data, session)

        session.add.assert_called_once()
        session.commit.assert_called_once()

    def test_save_posts_to_db_existing(self, test_scraper):
        """A duplicate post is skipped (no add) but the session still commits."""
        posts_data = [
            RedditPostData(
                post_id="abc123",
                title="Test post",
                text="Test content",
                upvotes=100,
                created_at=datetime.now(timezone.utc),
                match_id=1,
                subreddit="soccer",
            )
        ]

        # Existing row found → dedupe path.
        session = Mock()
        session.query.return_value.filter.return_value.first.return_value = Mock()

        test_scraper.save_posts_to_db(posts_data, session)

        session.add.assert_not_called()
        session.commit.assert_called_once()

    def test_scrape_posts_api_error_continues(self, test_scraper, mock_praw, caplog):
        """A PRAW error during listing is logged and yields an empty result."""
        fake_subreddit = Mock()
        fake_subreddit.new.side_effect = praw.exceptions.PRAWException("API Error")
        fake_reddit = Mock()
        fake_reddit.subreddit.return_value = fake_subreddit
        mock_praw.return_value = fake_reddit

        with caplog.at_level("ERROR"):
            result = test_scraper.scrape_posts(
                subreddit="soccer",
                match_id=1,
                keywords=["test"],
            )

        assert result == []
        assert "Reddit API error" in caplog.text

    def test_scrape_comments_api_error_continues(self, test_scraper, caplog):
        """A PRAW error while listing comments is logged, not raised."""
        parent_post = Mock()
        parent_post.comments.list.side_effect = praw.exceptions.PRAWException("API Error")

        with caplog.at_level("ERROR"):
            result = test_scraper.scrape_comments(post_id="abc123", post=parent_post)

        assert result == []
        assert "Reddit API error" in caplog.text

    def test_save_posts_db_error_rollback(self, test_scraper, caplog):
        """A commit failure rolls the session back and re-raises."""
        posts_data = [
            RedditPostData(
                post_id="abc123",
                title="Test post",
                text="Test content",
                upvotes=100,
                created_at=datetime.now(timezone.utc),
                match_id=1,
                subreddit="soccer",
            )
        ]

        session = Mock()
        session.query.return_value.filter.return_value.first.return_value = None
        session.commit.side_effect = Exception("Database error")

        with caplog.at_level("ERROR"):
            with pytest.raises(Exception):
                test_scraper.save_posts_to_db(posts_data, session)

        session.rollback.assert_called_once()
        assert "Failed to save Reddit posts" in caplog.text

    def test_scrape_reddit_match_continues_on_subreddit_error(
        self, test_scraper, mock_praw, caplog
    ):
        """One failing subreddit does not abort the remaining subreddits."""
        # First subreddit raises, second returns an empty listing.
        broken_subreddit = Mock()
        broken_subreddit.new.side_effect = praw.exceptions.PRAWException("API Error")
        empty_subreddit = Mock()
        empty_subreddit.new.return_value = []

        fake_reddit = Mock()
        fake_reddit.subreddit.side_effect = [broken_subreddit, empty_subreddit]
        mock_praw.return_value = fake_reddit

        test_scraper.subreddits = ["soccer", "football"]

        with caplog.at_level("ERROR"):
            result = test_scraper.scrape_reddit_match(match_id=1)

        assert result['posts'] == []
        assert result['comments'] == []
        assert "Continuing with other sources" in caplog.text
|
||
|
||
|
||
class TestCreateRedditScraper:
    """Tests for the create_reddit_scraper factory function."""

    def test_factory_function(self, test_client_id, test_client_secret):
        """The factory constructs and returns a RedditScraper instance."""
        with patch('app.scrapers.reddit_scraper.RedditScraper') as scraper_cls:
            instance = Mock()
            scraper_cls.return_value = instance

            result = create_reddit_scraper(
                client_id=test_client_id,
                client_secret=test_client_secret,
            )

            scraper_cls.assert_called_once()
            assert result == instance
|