Add downloading for media in tweets
This commit is contained in:
parent
d1ba8b48f3
commit
9734c09426
104
persistence/media_download.go
Normal file
104
persistence/media_download.go
Normal file
@ -0,0 +1,104 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"net/http"
|
||||
"io/ioutil"
|
||||
|
||||
"offline_twitter/scraper"
|
||||
)
|
||||
|
||||
/**
 * Anything that can fetch a remote file and store it on disk.
 *
 * Having this as an interface lets callers swap the real downloader for a stub.
 */
type MediaDownloader interface {
	/**
	 * Fetch `url` and save the result at `outpath`; a non-nil error means the download failed.
	 */
	Curl(url, outpath string) error
}
|
||||
|
||||
/**
 * The real MediaDownloader implementation: fetches files with `http.Get`.
 */
type DefaultDownloader struct{}

/**
 * Download a file over HTTP and save it.
 *
 * args:
 * - url: the remote file to download
 * - outpath: the path on disk to save it to
 *
 * returns:
 * - nil on success
 * - an error if the request fails, the response status is not 200, the body
 *   cannot be read, or the file cannot be written.  Read and write failures
 *   wrap the underlying error, so `errors.Is` / `errors.As` work on them.
 */
func (d DefaultDownloader) Curl(url string, outpath string) error {
	// Print the URL being fetched (the `println` builtin writes to stderr)
	println(url)

	resp, err := http.Get(url)
	if err != nil {
		return err
	}
	// The body must be closed on every path, including the non-200 early return
	// below; otherwise the underlying connection is leaked.
	defer resp.Body.Close()

	if resp.StatusCode != 200 {
		return fmt.Errorf("Error %s: %s", url, resp.Status)
	}

	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("Error downloading image %s: %w", url, err)
	}

	err = os.WriteFile(outpath, data, 0644)
	if err != nil {
		return fmt.Errorf("Error writing to path: %s, url: %s: %w", outpath, url, err)
	}
	return nil
}
|
||||
|
||||
/**
|
||||
* Downloads an Image, and if successful, marks it as downloaded in the DB
|
||||
*/
|
||||
func (p Profile) download_tweet_image(img *scraper.Image, downloader MediaDownloader) error {
|
||||
outfile := path.Join(p.ProfileDir, "images", img.LocalFilename)
|
||||
err := downloader.Curl(img.RemoteURL, outfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
img.IsDownloaded = true
|
||||
return p.SaveImage(*img)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Downloads an Video, and if successful, marks it as downloaded in the DB
|
||||
*/
|
||||
func (p Profile) download_tweet_video(v *scraper.Video, downloader MediaDownloader) error {
|
||||
outfile := path.Join(p.ProfileDir, "videos", v.LocalFilename)
|
||||
err := downloader.Curl(v.RemoteURL, outfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.IsDownloaded = true
|
||||
return p.SaveVideo(*v)
|
||||
}
|
||||
|
||||
/**
|
||||
* Download a tweet's video and picture content.
|
||||
*
|
||||
* Wraps the `DownloadTweetContentWithInjector` method with the default (i.e., real) downloader.
|
||||
*/
|
||||
func (p Profile) DownloadTweetContentFor(t *scraper.Tweet) error {
|
||||
return p.DownloadTweetContentWithInjector(t, DefaultDownloader{})
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Enable injecting a custom MediaDownloader (i.e., for testing)
|
||||
*/
|
||||
func (p Profile) DownloadTweetContentWithInjector(t *scraper.Tweet, downloader MediaDownloader) error {
|
||||
for i := range t.Images {
|
||||
err := p.download_tweet_image(&t.Images[i], downloader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for i := range t.Videos {
|
||||
err := p.download_tweet_video(&t.Videos[i], downloader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
t.IsContentDownloaded = true
|
||||
return p.SaveTweet(*t)
|
||||
}
|
72
persistence/media_download_test.go
Normal file
72
persistence/media_download_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
package persistence_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"offline_twitter/scraper"
|
||||
)
|
||||
|
||||
/**
 * A downloader stub for tests: `Curl` does nothing and always reports success.
 */
type FakeDownloader struct{}

func (d FakeDownloader) Curl(url string, outpath string) error {
	return nil
}
|
||||
|
||||
func test_all_downloaded(tweet scraper.Tweet, yes_or_no bool, t *testing.T) {
|
||||
error_msg := map[bool]string{
|
||||
true: "Expected to be downloaded, but it wasn't",
|
||||
false: "Expected not to be downloaded, but it was",
|
||||
}[yes_or_no]
|
||||
|
||||
if len(tweet.Images) != 2 {
|
||||
t.Errorf("Expected %d images, got %d", 2, len(tweet.Images))
|
||||
}
|
||||
if len(tweet.Videos) != 1 {
|
||||
t.Errorf("Expected %d videos, got %d", 1, len(tweet.Videos))
|
||||
}
|
||||
for _, img := range tweet.Images {
|
||||
if img.IsDownloaded != yes_or_no {
|
||||
t.Errorf("%s: ImageID %d", error_msg, img.ID)
|
||||
}
|
||||
}
|
||||
for _, vid := range tweet.Videos {
|
||||
if vid.IsDownloaded != yes_or_no {
|
||||
t.Errorf("Expected not to be downloaded, but it was: VideoID %d", vid.ID)
|
||||
}
|
||||
}
|
||||
if tweet.IsContentDownloaded != yes_or_no {
|
||||
t.Errorf("%s: the tweet", error_msg)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an Image, save it, reload it, and make sure it comes back the same
|
||||
*/
|
||||
func TestDownloadTweetContent(t *testing.T) {
|
||||
profile_path := "test_profiles/TestMediaQueries"
|
||||
profile := create_or_load_profile(profile_path)
|
||||
|
||||
tweet := create_dummy_tweet()
|
||||
|
||||
// Persist the tweet
|
||||
err := profile.SaveTweet(tweet)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save the tweet: %s", err.Error())
|
||||
}
|
||||
|
||||
// Make sure everything is marked "not downloaded"
|
||||
test_all_downloaded(tweet, false, t)
|
||||
|
||||
// Do the (fake) downloading
|
||||
err = profile.DownloadTweetContentWithInjector(&tweet, FakeDownloader{})
|
||||
if err != nil {
|
||||
t.Fatalf("Error running fake download: %s", err.Error())
|
||||
}
|
||||
|
||||
// It should all be marked "yes downloaded" now
|
||||
test_all_downloaded(tweet, true, t)
|
||||
|
||||
// Reload the tweet (check db); should also be "yes downloaded"
|
||||
new_tweet, err := profile.GetTweetById(tweet.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't reload the tweet: %s", err.Error())
|
||||
}
|
||||
test_all_downloaded(new_tweet, true, t)
|
||||
}
|
@ -17,16 +17,17 @@ func (p Profile) SaveTweet(t scraper.Tweet) error {
|
||||
return err
|
||||
}
|
||||
_, err = db.Exec(`
|
||||
insert into tweets (id, user_id, text, posted_at, num_likes, num_retweets, num_replies, num_quote_tweets, in_reply_to, quoted_tweet, mentions, hashtags)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
insert into tweets (id, user_id, text, posted_at, num_likes, num_retweets, num_replies, num_quote_tweets, in_reply_to, quoted_tweet, mentions, hashtags, is_content_downloaded)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
on conflict do update
|
||||
set num_likes=?,
|
||||
num_retweets=?,
|
||||
num_replies=?,
|
||||
num_quote_tweets=?
|
||||
num_quote_tweets=?,
|
||||
is_content_downloaded=?
|
||||
`,
|
||||
t.ID, t.UserID, t.Text, t.PostedAt.Unix(), t.NumLikes, t.NumRetweets, t.NumReplies, t.NumQuoteTweets, t.InReplyTo, t.QuotedTweet, scraper.JoinArrayOfHandles(t.Mentions), strings.Join(t.Hashtags, ","),
|
||||
t.NumLikes, t.NumRetweets, t.NumReplies, t.NumQuoteTweets,
|
||||
t.ID, t.UserID, t.Text, t.PostedAt.Unix(), t.NumLikes, t.NumRetweets, t.NumReplies, t.NumQuoteTweets, t.InReplyTo, t.QuotedTweet, scraper.JoinArrayOfHandles(t.Mentions), strings.Join(t.Hashtags, ","), t.IsContentDownloaded,
|
||||
t.NumLikes, t.NumRetweets, t.NumReplies, t.NumQuoteTweets, t.IsContentDownloaded,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
@ -106,7 +107,7 @@ func (p Profile) GetTweetById(id scraper.TweetID) (scraper.Tweet, error) {
|
||||
db := p.DB
|
||||
|
||||
stmt, err := db.Prepare(`
|
||||
select id, user_id, text, posted_at, num_likes, num_retweets, num_replies, num_quote_tweets, in_reply_to, quoted_tweet, mentions, hashtags
|
||||
select id, user_id, text, posted_at, num_likes, num_retweets, num_replies, num_quote_tweets, in_reply_to, quoted_tweet, mentions, hashtags, is_content_downloaded
|
||||
from tweets
|
||||
where id = ?
|
||||
`)
|
||||
@ -122,7 +123,7 @@ func (p Profile) GetTweetById(id scraper.TweetID) (scraper.Tweet, error) {
|
||||
var hashtags string
|
||||
|
||||
row := stmt.QueryRow(id)
|
||||
err = row.Scan(&t.ID, &t.UserID, &t.Text, &postedAt, &t.NumLikes, &t.NumRetweets, &t.NumReplies, &t.NumQuoteTweets, &t.InReplyTo, &t.QuotedTweet, &mentions, &hashtags)
|
||||
err = row.Scan(&t.ID, &t.UserID, &t.Text, &postedAt, &t.NumLikes, &t.NumRetweets, &t.NumReplies, &t.NumQuoteTweets, &t.InReplyTo, &t.QuotedTweet, &mentions, &hashtags, &t.IsContentDownloaded)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ func TestSaveAndLoadTweet(t *testing.T) {
|
||||
profile := create_or_load_profile(profile_path)
|
||||
|
||||
tweet := create_dummy_tweet()
|
||||
tweet.IsContentDownloaded = true
|
||||
|
||||
// Save the tweet
|
||||
err := profile.SaveTweet(tweet)
|
||||
|
@ -29,6 +29,8 @@ type Tweet struct {
|
||||
Mentions []UserHandle
|
||||
Hashtags []string
|
||||
QuotedTweet TweetID
|
||||
|
||||
IsContentDownloaded bool
|
||||
}
|
||||
|
||||
|
||||
|
@ -35,6 +35,8 @@ type User struct {
|
||||
BannerImageUrl string
|
||||
PinnedTweetID TweetID
|
||||
PinnedTweet *Tweet
|
||||
|
||||
IsContentDownloaded bool
|
||||
}
|
||||
|
||||
func (u User) String() string {
|
||||
|
Loading…
x
Reference in New Issue
Block a user