diff --git a/media_downloader.py b/media_downloader.py index 690a7fd..dddd098 100644 --- a/media_downloader.py +++ b/media_downloader.py @@ -118,7 +118,6 @@ async def _get_media_meta( # pylint: disable = C0209 file_format = media_obj.mime_type.split("/")[-1] # type: ignore file_name: str = os.path.join( - THIS_DIR, _type, "{}_{}.{}".format( _type, @@ -127,9 +126,7 @@ async def _get_media_meta( ), ) else: - file_name = os.path.join( - THIS_DIR, _type, getattr(media_obj, "file_name", None) or "" - ) + file_name = os.path.join(_type, getattr(media_obj, "file_name", None) or "") return file_name, file_format @@ -138,6 +135,7 @@ async def download_media( message: pyrogram.types.Message, media_types: List[str], file_formats: dict, + upload_dir: str = THIS_DIR, ): """ Download media from Telegram. @@ -179,6 +177,7 @@ async def download_media( if _media is None: continue file_name, file_format = await _get_media_meta(_media, _type) + file_name = os.path.join(upload_dir, file_name) if _can_download(_type, file_formats, file_format): if _is_exist(file_name): file_name = get_next_name(file_name) @@ -242,6 +241,7 @@ async def process_messages( messages: List[pyrogram.types.Message], media_types: List[str], file_formats: dict, + upload_dir: str, ) -> int: """ Download media from Telegram. @@ -273,7 +273,7 @@ async def process_messages( """ message_ids = await asyncio.gather( *[ - download_media(client, message, media_types, file_formats) + download_media(client, message, media_types, file_formats, upload_dir=upload_dir) for message in messages ] ) @@ -324,6 +324,8 @@ async def begin_import(config: dict, pagination_limit: int) -> dict: pagination_count += 1 messages_list.append(message) + upload_dir: str = config.get("upload_dir", THIS_DIR) + async for message in messages_iter: # type: ignore if pagination_count != pagination_limit: pagination_count += 1 @@ -334,6 +336,7 @@ async def begin_import(config: dict, pagination_limit: int) -> dict: messages_list, config["media_types"], config["file_formats"], + upload_dir, ) pagination_count = 0 messages_list = [] @@ -346,6 +349,7 @@ async def begin_import(config: dict, pagination_limit: int) -> dict: messages_list, config["media_types"], config["file_formats"], + upload_dir, ) await client.stop() diff --git a/tests/test_media_downloader.py b/tests/test_media_downloader.py index 33d7bda..0e9412b 100644 --- a/tests/test_media_downloader.py +++ b/tests/test_media_downloader.py @@ -118,9 +118,11 @@ async def async_get_media_meta(message_media, _type): return result -async def async_download_media(client, message, media_types, file_formats): - result = await download_media(client, message, media_types, file_formats) - return result +async def async_download_media(client, message, media_types, file_formats, upload_dir = None): + if upload_dir != None: + return await download_media(client, message, media_types, file_formats, upload_dir=upload_dir) + + return await download_media(client, message, media_types, file_formats) async def async_begin_import(conf, pagination_limit): @@ -132,8 +134,8 @@ async def mock_process_message(*args, **kwargs): return 5 -async def async_process_messages(client, messages, media_types, file_formats): - result = await process_messages(client, messages, media_types, file_formats) +async def async_process_messages(client, messages, media_types, file_formats, upload_dir): + result = await process_messages(client, messages, media_types, file_formats, upload_dir) return result @@ -230,7 +232,6 @@ class MediaDownloaderTestCase(unittest.TestCase): def setUpClass(cls): cls.loop = asyncio.get_event_loop() - @mock.patch("media_downloader.THIS_DIR", new=MOCK_DIR) def test_get_media_meta(self): # Test Voice notes message = MockMessage( @@ -248,7 +249,7 @@ def test_get_media_meta(self): self.assertEqual( ( platform_generic_path( - "/root/project/voice/voice_2019-07-25T14:53:50.ogg" + "voice/voice_2019-07-25T14:53:50.ogg" ), "ogg", ), @@ -266,7 +267,7 @@ def test_get_media_meta(self): ) self.assertEqual( ( - platform_generic_path("/root/project/photo/"), + platform_generic_path("photo/"), None, ), result, @@ -286,7 +287,7 @@ def test_get_media_meta(self): ) self.assertEqual( ( - platform_generic_path("/root/project/document/sample_document.pdf"), + platform_generic_path("document/sample_document.pdf"), "pdf", ), result, @@ -306,7 +307,7 @@ def test_get_media_meta(self): ) self.assertEqual( ( - platform_generic_path("/root/project/audio/sample_audio.mp3"), + platform_generic_path("audio/sample_audio.mp3"), "mp3", ), result, @@ -325,7 +326,7 @@ def test_get_media_meta(self): ) self.assertEqual( ( - platform_generic_path("/root/project/video/"), + platform_generic_path("video/"), "mp4", ), result, @@ -346,7 +347,7 @@ def test_get_media_meta(self): self.assertEqual( ( platform_generic_path( - "/root/project/video_note/video_note_2019-07-25T14:53:50.mp4" + "video_note/video_note_2019-07-25T14:53:50.mp4" ), "mp4", ), @@ -480,6 +481,35 @@ def test_download_media(self, mock_logger, patched_time_sleep): "Message[%d]: Timing out after 3 reties, download skipped.", 11 ) + @mock.patch.object(download_media,"__defaults__", (MOCK_DIR,)) + def test_download_media_default_upload_dir(self): + client = mock.MagicMock() + message = MockMessage( + id=5, + media=True, + video=MockVideo(mime_type="video/mp4"), + ) + self.loop.run_until_complete( + async_download_media( + client, message, ["video"], {"video": ["mp4"]} + ) + ) + client.download_media.assert_called_with(message, file_name=MOCK_DIR + "/video/") + + def test_download_media_custom_upload_dir(self): + client = mock.MagicMock() + message = MockMessage( + id=5, + media=True, + video=MockVideo(mime_type="video/mp4"), + ) + self.loop.run_until_complete( + async_download_media( + client, message, ["video"], {"video": ["mp4"]}, upload_dir= "/custom/path/" + ) + ) + client.download_media.assert_called_with(message, file_name="/custom/path/video/") + @mock.patch("__main__.__builtins__.open", new_callable=mock.mock_open) @mock.patch("media_downloader.yaml", autospec=True) def test_update_config(self, mock_yaml, mock_open): @@ -492,14 +522,17 @@ def test_update_config(self, mock_yaml, mock_open): mock_open.assert_called_with("config.yaml", "w") mock_yaml.dump.assert_called_with(conf, mock.ANY, default_flow_style=False) + @mock.patch("media_downloader.process_messages") @mock.patch("media_downloader.update_config") @mock.patch("media_downloader.pyrogram.Client", new=MockClient) - @mock.patch("media_downloader.process_messages", new=mock_process_message) - def test_begin_import(self, mock_update_config): + @mock.patch("media_downloader.THIS_DIR", new=MOCK_DIR) + def test_begin_import(self, mock_update_config, mock_process_messages): + mock_process_messages.return_value = 5 result = self.loop.run_until_complete(async_begin_import(MOCK_CONF, 3)) conf = copy.deepcopy(MOCK_CONF) conf["last_read_message_id"] = 5 self.assertDictEqual(result, conf) + mock_process_messages.assert_called_with(mock.ANY, mock.ANY, mock.ANY, mock.ANY, MOCK_DIR) def test_process_message(self): client = MockClient() @@ -533,6 +566,7 @@ def test_process_message(self): ], ["voice", "photo"], {"audio": ["all"], "voice": ["all"]}, + MOCK_DIR ) ) self.assertEqual(result, 1216) @@ -574,6 +608,7 @@ def test_process_message_when_file_exists(self, mock_is_exist): ], ["voice", "photo"], {"audio": ["all"], "voice": ["all"]}, + MOCK_DIR ) ) self.assertEqual(result, 1216)