From 9734c09426b8503eb5d9f582cdc29c30ea38447e Mon Sep 17 00:00:00 2001 From: Alessio Date: Sat, 7 Aug 2021 16:51:38 -0700 Subject: [PATCH] Add downloading for media in tweets --- persistence/media_download.go | 104 +++++++++++++++++++++++++++++ persistence/media_download_test.go | 72 ++++++++++++++++++++ persistence/tweet_queries.go | 15 +++-- persistence/tweet_queries_test.go | 1 + scraper/tweet.go | 2 + scraper/user.go | 2 + 6 files changed, 189 insertions(+), 7 deletions(-) create mode 100644 persistence/media_download.go create mode 100644 persistence/media_download_test.go diff --git a/persistence/media_download.go b/persistence/media_download.go new file mode 100644 index 0000000..bfc7ed4 --- /dev/null +++ b/persistence/media_download.go @@ -0,0 +1,104 @@ +package persistence + +import ( + "fmt" + "os" + "path" + "net/http" + "io/ioutil" + + "offline_twitter/scraper" +) + +type MediaDownloader interface { + Curl(url string, outpath string) error +} + +type DefaultDownloader struct {} + +/** + * Download a file over HTTP and save it. + * + * args: + * - url: the remote file to download + * - outpath: the path on disk to save it to + */ +func (d DefaultDownloader) Curl(url string, outpath string) error { + println(url) + resp, err := http.Get(url) + if err != nil { + return err + } + if resp.StatusCode != 200 { + return fmt.Errorf("Error %s: %s", url, resp.Status) + } + + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("Error downloading image %s: %s", url, err.Error()) + } + + err = os.WriteFile(outpath, data, 0644) + if err != nil { + return fmt.Errorf("Error writing to path: %s, url: %s: %s", outpath, url, err.Error()) + } + return nil +} + +/** + * Downloads an Image, and if successful, marks it as downloaded in the DB + */ +func (p Profile) download_tweet_image(img *scraper.Image, downloader MediaDownloader) error { + outfile := path.Join(p.ProfileDir, "images", img.LocalFilename) + err := downloader.Curl(img.RemoteURL, outfile) + if err != nil { + return err + } + img.IsDownloaded = true + return p.SaveImage(*img) +} + + +/** + * Downloads an Video, and if successful, marks it as downloaded in the DB + */ +func (p Profile) download_tweet_video(v *scraper.Video, downloader MediaDownloader) error { + outfile := path.Join(p.ProfileDir, "videos", v.LocalFilename) + err := downloader.Curl(v.RemoteURL, outfile) + if err != nil { + return err + } + v.IsDownloaded = true + return p.SaveVideo(*v) +} + +/** + * Download a tweet's video and picture content. + * + * Wraps the `DownloadTweetContentWithInjector` method with the default (i.e., real) downloader. + */ +func (p Profile) DownloadTweetContentFor(t *scraper.Tweet) error { + return p.DownloadTweetContentWithInjector(t, DefaultDownloader{}) +} + + +/** + * Enable injecting a custom MediaDownloader (i.e., for testing) + */ +func (p Profile) DownloadTweetContentWithInjector(t *scraper.Tweet, downloader MediaDownloader) error { + for i := range t.Images { + err := p.download_tweet_image(&t.Images[i], downloader) + if err != nil { + return err + } + } + + for i := range t.Videos { + err := p.download_tweet_video(&t.Videos[i], downloader) + if err != nil { + return err + } + } + t.IsContentDownloaded = true + return p.SaveTweet(*t) +} diff --git a/persistence/media_download_test.go b/persistence/media_download_test.go new file mode 100644 index 0000000..267a0e1 --- /dev/null +++ b/persistence/media_download_test.go @@ -0,0 +1,72 @@ +package persistence_test + +import ( + "testing" + + "offline_twitter/scraper" +) + +type FakeDownloader struct {} +func (d FakeDownloader) Curl(url string, outpath string) error { return nil } + +func test_all_downloaded(tweet scraper.Tweet, yes_or_no bool, t *testing.T) { + error_msg := map[bool]string{ + true: "Expected to be downloaded, but it wasn't", + false: "Expected not to be downloaded, but it was", + }[yes_or_no] + + if len(tweet.Images) != 2 { + t.Errorf("Expected %d images, got %d", 2, len(tweet.Images)) + } + if len(tweet.Videos) != 1 { + t.Errorf("Expected %d videos, got %d", 1, len(tweet.Videos)) + } + for _, img := range tweet.Images { + if img.IsDownloaded != yes_or_no { + t.Errorf("%s: ImageID %d", error_msg, img.ID) + } + } + for _, vid := range tweet.Videos { + if vid.IsDownloaded != yes_or_no { + t.Errorf("Expected not to be downloaded, but it was: VideoID %d", vid.ID) + } + } + if tweet.IsContentDownloaded != yes_or_no { + t.Errorf("%s: the tweet", error_msg) + } +} + +/** + * Create an Image, save it, reload it, and make sure it comes back the same + */ +func TestDownloadTweetContent(t *testing.T) { + profile_path := "test_profiles/TestMediaQueries" + profile := create_or_load_profile(profile_path) + + tweet := create_dummy_tweet() + + // Persist the tweet + err := profile.SaveTweet(tweet) + if err != nil { + t.Fatalf("Failed to save the tweet: %s", err.Error()) + } + + // Make sure everything is marked "not downloaded" + test_all_downloaded(tweet, false, t) + + // Do the (fake) downloading + err = profile.DownloadTweetContentWithInjector(&tweet, FakeDownloader{}) + if err != nil { + t.Fatalf("Error running fake download: %s", err.Error()) + } + + // It should all be marked "yes downloaded" now + test_all_downloaded(tweet, true, t) + + // Reload the tweet (check db); should also be "yes downloaded" + new_tweet, err := profile.GetTweetById(tweet.ID) + if err != nil { + t.Fatalf("Couldn't reload the tweet: %s", err.Error()) + } + test_all_downloaded(new_tweet, true, t) +} diff --git a/persistence/tweet_queries.go b/persistence/tweet_queries.go index 0e9d28f..82f52c2 100644 --- a/persistence/tweet_queries.go +++ b/persistence/tweet_queries.go @@ -17,16 +17,17 @@ func (p Profile) SaveTweet(t scraper.Tweet) error { return err } _, err = db.Exec(` - insert into tweets (id, user_id, text, posted_at, num_likes, num_retweets, num_replies, num_quote_tweets, in_reply_to, quoted_tweet, mentions, hashtags) - values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + insert into tweets (id, user_id, text, posted_at, num_likes, num_retweets, num_replies, num_quote_tweets, in_reply_to, quoted_tweet, mentions, hashtags, is_content_downloaded) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) on conflict do update set num_likes=?, num_retweets=?, num_replies=?, - num_quote_tweets=? + num_quote_tweets=?, + is_content_downloaded=? `, - t.ID, t.UserID, t.Text, t.PostedAt.Unix(), t.NumLikes, t.NumRetweets, t.NumReplies, t.NumQuoteTweets, t.InReplyTo, t.QuotedTweet, scraper.JoinArrayOfHandles(t.Mentions), strings.Join(t.Hashtags, ","), - t.NumLikes, t.NumRetweets, t.NumReplies, t.NumQuoteTweets, + t.ID, t.UserID, t.Text, t.PostedAt.Unix(), t.NumLikes, t.NumRetweets, t.NumReplies, t.NumQuoteTweets, t.InReplyTo, t.QuotedTweet, scraper.JoinArrayOfHandles(t.Mentions), strings.Join(t.Hashtags, ","), t.IsContentDownloaded, + t.NumLikes, t.NumRetweets, t.NumReplies, t.NumQuoteTweets, t.IsContentDownloaded, ) if err != nil { @@ -106,7 +107,7 @@ func (p Profile) GetTweetById(id scraper.TweetID) (scraper.Tweet, error) { db := p.DB stmt, err := db.Prepare(` - select id, user_id, text, posted_at, num_likes, num_retweets, num_replies, num_quote_tweets, in_reply_to, quoted_tweet, mentions, hashtags + select id, user_id, text, posted_at, num_likes, num_retweets, num_replies, num_quote_tweets, in_reply_to, quoted_tweet, mentions, hashtags, is_content_downloaded from tweets where id = ? `) @@ -122,7 +123,7 @@ func (p Profile) GetTweetById(id scraper.TweetID) (scraper.Tweet, error) { var hashtags string row := stmt.QueryRow(id) - err = row.Scan(&t.ID, &t.UserID, &t.Text, &postedAt, &t.NumLikes, &t.NumRetweets, &t.NumReplies, &t.NumQuoteTweets, &t.InReplyTo, &t.QuotedTweet, &mentions, &hashtags) + err = row.Scan(&t.ID, &t.UserID, &t.Text, &postedAt, &t.NumLikes, &t.NumRetweets, &t.NumReplies, &t.NumQuoteTweets, &t.InReplyTo, &t.QuotedTweet, &mentions, &hashtags, &t.IsContentDownloaded) if err != nil { return t, err } diff --git a/persistence/tweet_queries_test.go b/persistence/tweet_queries_test.go index 878452b..eb44df5 100644 --- a/persistence/tweet_queries_test.go +++ b/persistence/tweet_queries_test.go @@ -15,6 +15,7 @@ func TestSaveAndLoadTweet(t *testing.T) { profile := create_or_load_profile(profile_path) tweet := create_dummy_tweet() + tweet.IsContentDownloaded = true // Save the tweet err := profile.SaveTweet(tweet) diff --git a/scraper/tweet.go b/scraper/tweet.go index 26a38d8..a3c20e8 100644 --- a/scraper/tweet.go +++ b/scraper/tweet.go @@ -29,6 +29,8 @@ type Tweet struct { Mentions []UserHandle Hashtags []string QuotedTweet TweetID + + IsContentDownloaded bool } diff --git a/scraper/user.go b/scraper/user.go index a31b988..206507d 100644 --- a/scraper/user.go +++ b/scraper/user.go @@ -35,6 +35,8 @@ type User struct { BannerImageUrl string PinnedTweetID TweetID PinnedTweet *Tweet + + IsContentDownloaded bool } func (u User) String() string {