Skip to content

Commit

Permalink
upload dir
Browse files Browse the repository at this point in the history
  • Loading branch information
micronull committed Nov 12, 2023
1 parent 91220f8 commit 20672ac
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 19 deletions.
14 changes: 9 additions & 5 deletions media_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ async def _get_media_meta(
# pylint: disable = C0209
file_format = media_obj.mime_type.split("/")[-1] # type: ignore
file_name: str = os.path.join(
THIS_DIR,
_type,
"{}_{}.{}".format(
_type,
Expand All @@ -127,9 +126,7 @@ async def _get_media_meta(
),
)
else:
file_name = os.path.join(
THIS_DIR, _type, getattr(media_obj, "file_name", None) or ""
)
file_name = os.path.join(_type, getattr(media_obj, "file_name", None) or "")
return file_name, file_format


Expand All @@ -138,6 +135,7 @@ async def download_media(
message: pyrogram.types.Message,
media_types: List[str],
file_formats: dict,
upload_dir: str = THIS_DIR,
):
"""
Download media from Telegram.
Expand Down Expand Up @@ -179,6 +177,7 @@ async def download_media(
if _media is None:
continue
file_name, file_format = await _get_media_meta(_media, _type)
file_name = os.path.join(upload_dir, file_name)
if _can_download(_type, file_formats, file_format):
if _is_exist(file_name):
file_name = get_next_name(file_name)
Expand Down Expand Up @@ -242,6 +241,7 @@ async def process_messages(
messages: List[pyrogram.types.Message],
media_types: List[str],
file_formats: dict,
upload_dir: str,
) -> int:
"""
Download media from Telegram.
Expand Down Expand Up @@ -273,7 +273,7 @@ async def process_messages(
"""
message_ids = await asyncio.gather(
*[
download_media(client, message, media_types, file_formats)
download_media(client, message, media_types, file_formats, upload_dir=upload_dir)
for message in messages
]
)
Expand Down Expand Up @@ -324,6 +324,8 @@ async def begin_import(config: dict, pagination_limit: int) -> dict:
pagination_count += 1
messages_list.append(message)

upload_dir: str = config.get("upload_dir", THIS_DIR)

async for message in messages_iter: # type: ignore
if pagination_count != pagination_limit:
pagination_count += 1
Expand All @@ -334,6 +336,7 @@ async def begin_import(config: dict, pagination_limit: int) -> dict:
messages_list,
config["media_types"],
config["file_formats"],
upload_dir,
)
pagination_count = 0
messages_list = []
Expand All @@ -346,6 +349,7 @@ async def begin_import(config: dict, pagination_limit: int) -> dict:
messages_list,
config["media_types"],
config["file_formats"],
upload_dir,
)

await client.stop()
Expand Down
63 changes: 49 additions & 14 deletions tests/test_media_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,11 @@ async def async_get_media_meta(message_media, _type):
return result


async def async_download_media(client, message, media_types, file_formats):
result = await download_media(client, message, media_types, file_formats)
return result
async def async_download_media(client, message, media_types, file_formats, upload_dir = None):
if upload_dir != None:
return await download_media(client, message, media_types, file_formats, upload_dir=upload_dir)

return await download_media(client, message, media_types, file_formats)


async def async_begin_import(conf, pagination_limit):
Expand All @@ -132,8 +134,8 @@ async def mock_process_message(*args, **kwargs):
return 5


async def async_process_messages(client, messages, media_types, file_formats):
result = await process_messages(client, messages, media_types, file_formats)
async def async_process_messages(client, messages, media_types, file_formats, upload_dir):
result = await process_messages(client, messages, media_types, file_formats, upload_dir)
return result


Expand Down Expand Up @@ -230,7 +232,6 @@ class MediaDownloaderTestCase(unittest.TestCase):
def setUpClass(cls):
cls.loop = asyncio.get_event_loop()

@mock.patch("media_downloader.THIS_DIR", new=MOCK_DIR)
def test_get_media_meta(self):
# Test Voice notes
message = MockMessage(
Expand All @@ -248,7 +249,7 @@ def test_get_media_meta(self):
self.assertEqual(
(
platform_generic_path(
"/root/project/voice/voice_2019-07-25T14:53:50.ogg"
"voice/voice_2019-07-25T14:53:50.ogg"
),
"ogg",
),
Expand All @@ -266,7 +267,7 @@ def test_get_media_meta(self):
)
self.assertEqual(
(
platform_generic_path("/root/project/photo/"),
platform_generic_path("photo/"),
None,
),
result,
Expand All @@ -286,7 +287,7 @@ def test_get_media_meta(self):
)
self.assertEqual(
(
platform_generic_path("/root/project/document/sample_document.pdf"),
platform_generic_path("document/sample_document.pdf"),
"pdf",
),
result,
Expand All @@ -306,7 +307,7 @@ def test_get_media_meta(self):
)
self.assertEqual(
(
platform_generic_path("/root/project/audio/sample_audio.mp3"),
platform_generic_path("audio/sample_audio.mp3"),
"mp3",
),
result,
Expand All @@ -325,7 +326,7 @@ def test_get_media_meta(self):
)
self.assertEqual(
(
platform_generic_path("/root/project/video/"),
platform_generic_path("video/"),
"mp4",
),
result,
Expand All @@ -346,7 +347,7 @@ def test_get_media_meta(self):
self.assertEqual(
(
platform_generic_path(
"/root/project/video_note/video_note_2019-07-25T14:53:50.mp4"
"video_note/video_note_2019-07-25T14:53:50.mp4"
),
"mp4",
),
Expand Down Expand Up @@ -480,6 +481,35 @@ def test_download_media(self, mock_logger, patched_time_sleep):
"Message[%d]: Timing out after 3 reties, download skipped.", 11
)

@mock.patch.object(download_media,"__defaults__", (MOCK_DIR,))
def test_download_media_default_upload_dir(self):
client = mock.MagicMock()
message = MockMessage(
id=5,
media=True,
video=MockVideo(mime_type="video/mp4"),
)
self.loop.run_until_complete(
async_download_media(
client, message, ["video"], {"video": ["mp4"]}
)
)
client.download_media.assert_called_with(message, file_name=MOCK_DIR + "/video/")

def test_download_media_custom_upload_dir(self):
client = mock.MagicMock()
message = MockMessage(
id=5,
media=True,
video=MockVideo(mime_type="video/mp4"),
)
self.loop.run_until_complete(
async_download_media(
client, message, ["video"], {"video": ["mp4"]}, upload_dir= "/custom/path/"
)
)
client.download_media.assert_called_with(message, file_name="/custom/path/video/")

@mock.patch("__main__.__builtins__.open", new_callable=mock.mock_open)
@mock.patch("media_downloader.yaml", autospec=True)
def test_update_config(self, mock_yaml, mock_open):
Expand All @@ -492,14 +522,17 @@ def test_update_config(self, mock_yaml, mock_open):
mock_open.assert_called_with("config.yaml", "w")
mock_yaml.dump.assert_called_with(conf, mock.ANY, default_flow_style=False)

@mock.patch("media_downloader.process_messages")
@mock.patch("media_downloader.update_config")
@mock.patch("media_downloader.pyrogram.Client", new=MockClient)
@mock.patch("media_downloader.process_messages", new=mock_process_message)
def test_begin_import(self, mock_update_config):
@mock.patch("media_downloader.THIS_DIR", new=MOCK_DIR)
def test_begin_import(self, mock_update_config, mock_process_messages):
mock_process_messages.return_value = 5
result = self.loop.run_until_complete(async_begin_import(MOCK_CONF, 3))
conf = copy.deepcopy(MOCK_CONF)
conf["last_read_message_id"] = 5
self.assertDictEqual(result, conf)
mock_process_messages.assert_called_with(mock.ANY, mock.ANY, mock.ANY, mock.ANY, MOCK_DIR)

def test_process_message(self):
client = MockClient()
Expand Down Expand Up @@ -533,6 +566,7 @@ def test_process_message(self):
],
["voice", "photo"],
{"audio": ["all"], "voice": ["all"]},
MOCK_DIR
)
)
self.assertEqual(result, 1216)
Expand Down Expand Up @@ -574,6 +608,7 @@ def test_process_message_when_file_exists(self, mock_is_exist):
],
["voice", "photo"],
{"audio": ["all"], "voice": ["all"]},
MOCK_DIR
)
)
self.assertEqual(result, 1216)
Expand Down

0 comments on commit 20672ac

Please sign in to comment.