From d692f237c10426c858b9cd402b5a1205045572ca Mon Sep 17 00:00:00 2001 From: Yevhenii Vaskivskyi Date: Sun, 19 Jan 2025 21:09:21 +0100 Subject: [PATCH] Use the new format of options flow --- custom_components/asusrouter/config_flow.py | 119 +++++++++++++------- 1 file changed, 81 insertions(+), 38 deletions(-) diff --git a/custom_components/asusrouter/config_flow.py b/custom_components/asusrouter/config_flow.py index 1ec2733..cfc9909 100644 --- a/custom_components/asusrouter/config_flow.py +++ b/custom_components/asusrouter/config_flow.py @@ -5,17 +5,11 @@ import logging import socket from typing import Any -from urllib.parse import urlparse import voluptuous as vol -from asusrouter import AsusData -from asusrouter.error import AsusRouterAccessError -from asusrouter.modules.endpoint.error import AccessError -from asusrouter.modules.homeassistant import convert_to_ha_sensors_group from homeassistant.config_entries import ConfigEntry, ConfigFlow, OptionsFlow from homeassistant.const import ( CONF_HOST, - CONF_NAME, CONF_PASSWORD, CONF_PORT, CONF_SCAN_INTERVAL, @@ -27,6 +21,11 @@ from homeassistant.helpers import config_validation as cv from homeassistant.helpers.device_registry import format_mac +from asusrouter import AsusData +from asusrouter.error import AsusRouterAccessError +from asusrouter.modules.endpoint.error import AccessError +from asusrouter.modules.homeassistant import convert_to_ha_sensors_group + from .bridge import ARBridge from .const import ( ACCESS_POINT, @@ -120,7 +119,11 @@ def _check_errors( if errors is None: return False - if BASE in errors and errors[BASE] != RESULT_SUCCESS and errors[BASE] != "": + if ( + BASE in errors + and errors[BASE] != RESULT_SUCCESS + and errors[BASE] != "" + ): return True return False @@ -148,7 +151,9 @@ async def _async_get_clients( } return result except Exception as ex: # pylint: disable=broad-except - _LOGGER.warning("Cannot get clients for %s: %s", configs[CONF_HOST], ex) + _LOGGER.warning( + "Cannot get clients for %s: %s", configs[CONF_HOST], ex + ) return {} @@ -172,7 +177,9 @@ async def _async_get_network_interfaces( return labels except Exception as ex: # pylint: disable=broad-except _LOGGER.warning( - "Cannot get available network interfaces for %s: %s", configs[CONF_HOST], ex + "Cannot get available network interfaces for %s: %s", + configs[CONF_HOST], + ex, ) return CONF_DEFAULT_INTERFACES @@ -206,7 +213,9 @@ async def _async_check_connection( args = ex.args # Wrong credentials if args[1] == AccessError.CREDENTIALS: - _LOGGER.error("Error during connection to `%s`. Wrong credentials", host) + _LOGGER.error( + "Error during connection to `%s`. Wrong credentials", host + ) return { ERRORS: RESULT_WRONG_CREDENTIALS, } @@ -317,7 +326,9 @@ async def _async_process_step( raise ValueError(f"Step `{step}` is not properly defined") # If the next step is defined, move to it if NEXT in description and description[NEXT]: - return await _async_process_step(steps, description[NEXT], redirect=True) + return await _async_process_step( + steps, description[NEXT], redirect=True + ) raise ValueError(f"Step `{step}` is not properly defined") raise ValueError(f"Step `{step}` cannot be found") @@ -334,7 +345,9 @@ def _create_form_find( user_input = {} schema = { - vol.Required(CONF_HOST, default=user_input.get(CONF_HOST, "")): cv.string, + vol.Required( + CONF_HOST, default=user_input.get(CONF_HOST, "") + ): cv.string, } return vol.Schema(schema) @@ -351,7 +364,8 @@ def _create_form_credentials( schema = { vol.Required( - CONF_USERNAME, default=user_input.get(CONF_USERNAME, CONF_DEFAULT_USERNAME) + CONF_USERNAME, + default=user_input.get(CONF_USERNAME, CONF_DEFAULT_USERNAME), ): cv.string, vol.Required( CONF_PASSWORD, default=user_input.get(CONF_PASSWORD, "") @@ -377,12 +391,19 @@ def _create_form_operation( user_input = {} schema = { - vol.Required(CONF_MODE, default=user_input.get(CONF_MODE, mode)): vol.In( - {mode: CONF_LABELS_MODE.get(mode, mode) for mode in CONF_VALUES_MODE} + vol.Required( + CONF_MODE, default=user_input.get(CONF_MODE, mode) + ): vol.In( + { + mode: CONF_LABELS_MODE.get(mode, mode) + for mode in CONF_VALUES_MODE + } ), vol.Required( CONF_SPLIT_INTERVALS, - default=user_input.get(CONF_SPLIT_INTERVALS, CONF_DEFAULT_SPLIT_INTERVALS), + default=user_input.get( + CONF_SPLIT_INTERVALS, CONF_DEFAULT_SPLIT_INTERVALS + ), ): cv.boolean, } @@ -405,25 +426,37 @@ def _create_form_connected_devices( schema = { vol.Required( CONF_TRACK_DEVICES, - default=user_input.get(CONF_TRACK_DEVICES, CONF_DEFAULT_TRACK_DEVICES), + default=user_input.get( + CONF_TRACK_DEVICES, CONF_DEFAULT_TRACK_DEVICES + ), ): cv.boolean, vol.Required( CONF_CLIENT_DEVICE, - default=user_input.get(CONF_CLIENT_DEVICE, CONF_DEFAULT_CLIENT_DEVICE), + default=user_input.get( + CONF_CLIENT_DEVICE, CONF_DEFAULT_CLIENT_DEVICE + ), ): cv.boolean, vol.Required( CONF_CLIENTS_IN_ATTR, - default=user_input.get(CONF_CLIENTS_IN_ATTR, CONF_DEFAULT_CLIENTS_IN_ATTR), + default=user_input.get( + CONF_CLIENTS_IN_ATTR, CONF_DEFAULT_CLIENTS_IN_ATTR + ), ): cv.boolean, vol.Required( CONF_CLIENT_FILTER, - default=user_input.get(CONF_CLIENT_FILTER, CONF_DEFAULT_CLIENT_FILTER), + default=user_input.get( + CONF_CLIENT_FILTER, CONF_DEFAULT_CLIENT_FILTER + ), ): vol.In(CONF_LABELS_CLIENT_FILTER), vol.Optional( CONF_CLIENT_FILTER_LIST, default=default, ): cv.multi_select( - dict(sorted(user_input[ALL_CLIENTS].items(), key=lambda item: item[1])) + dict( + sorted( + user_input[ALL_CLIENTS].items(), key=lambda item: item[1] + ) + ) ), vol.Required( CONF_LATEST_CONNECTED, @@ -433,7 +466,9 @@ def _create_form_connected_devices( ): cv.positive_int, vol.Required( CONF_INTERVAL_DEVICES, - default=user_input.get(CONF_INTERVAL_DEVICES, CONF_DEFAULT_SCAN_INTERVAL), + default=user_input.get( + CONF_INTERVAL_DEVICES, CONF_DEFAULT_SCAN_INTERVAL + ), ): cv.positive_int, } @@ -481,7 +516,9 @@ def _create_form_intervals( } split = user_input.get(CONF_SPLIT_INTERVALS, CONF_DEFAULT_SPLIT_INTERVALS) - conf_scan_interval = user_input.get(CONF_SCAN_INTERVAL, CONF_DEFAULT_SCAN_INTERVAL) + conf_scan_interval = user_input.get( + CONF_SCAN_INTERVAL, CONF_DEFAULT_SCAN_INTERVAL + ) if split is False: schema.update( @@ -505,7 +542,8 @@ def _create_form_intervals( vol.Required( conf, default=user_input.get( - conf, CONF_DEFAULT_INTERVALS.get(conf, conf_scan_interval) + conf, + CONF_DEFAULT_INTERVALS.get(conf, conf_scan_interval), ), ): cv.positive_int for conf in CONF_INTERVALS @@ -572,7 +610,9 @@ def _create_form_security( schema = { vol.Required( CONF_HIDE_PASSWORDS, - default=user_input.get(CONF_HIDE_PASSWORDS, CONF_DEFAULT_HIDE_PASSWORDS), + default=user_input.get( + CONF_HIDE_PASSWORDS, CONF_DEFAULT_HIDE_PASSWORDS + ), ): cv.boolean, } @@ -666,7 +706,9 @@ async def async_step_credentials( if user_input: # Check credentials and connection - result = await _async_check_connection(self.hass, self._configs, user_input) + result = await _async_check_connection( + self.hass, self._configs, user_input + ) # Show errors if any if ERRORS in result: errors[BASE] = result[ERRORS] @@ -753,7 +795,9 @@ async def async_step_connected_devices( ) return self.async_show_form( step_id=step_id, - data_schema=_create_form_connected_devices(user_input, self._mode), + data_schema=_create_form_connected_devices( + user_input, self._mode + ), ) self._options.update(user_input) @@ -859,10 +903,10 @@ async def async_step_finish( @callback def async_get_options_flow( config_entry: ConfigEntry, - ) -> OptionsFlow: + ) -> AROptionsFlowHandler: """Get the options flow.""" - return AROptionsFlowHandler(config_entry) + return AROptionsFlowHandler() class AROptionsFlowHandler(OptionsFlow): @@ -870,11 +914,14 @@ class AROptionsFlowHandler(OptionsFlow): def __init__( self, - config_entry: ConfigEntry, ) -> None: """Initialize options flow.""" - self.config_entry = config_entry + async def async_step_init( + self, + user_input: dict[str, Any] | None = None, + ) -> FlowResult: + """Options flow.""" self._selection: dict[str, Any] = {} self._configs: dict[str, Any] = self.config_entry.data.copy() @@ -882,12 +929,6 @@ def __init__( self._options: dict[str, Any] = self.config_entry.options.copy() self._mode = self._options.get(CONF_MODE, CONF_DEFAULT_MODE) - async def async_step_init( - self, - user_input: dict[str, Any] | None = None, - ) -> FlowResult: - """Options flow.""" - return await self.async_step_options(user_input) async def async_step_options( @@ -1051,7 +1092,9 @@ async def async_step_interfaces( user_input[INTERFACES].append(interface) return self.async_show_form( step_id=step_id, - data_schema=_create_form_interfaces(user_input, default=selected), + data_schema=_create_form_interfaces( + user_input, default=selected + ), ) self._options.update(user_input)