Skip to content

Commit

Permalink
Use the new format of options flow
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaskivskyi committed Jan 19, 2025
1 parent 3986902 commit d692f23
Showing 1 changed file with 81 additions and 38 deletions.
119 changes: 81 additions & 38 deletions custom_components/asusrouter/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,11 @@
import logging
import socket
from typing import Any
from urllib.parse import urlparse

import voluptuous as vol
from asusrouter import AsusData
from asusrouter.error import AsusRouterAccessError
from asusrouter.modules.endpoint.error import AccessError
from asusrouter.modules.homeassistant import convert_to_ha_sensors_group
from homeassistant.config_entries import ConfigEntry, ConfigFlow, OptionsFlow
from homeassistant.const import (
CONF_HOST,
CONF_NAME,
CONF_PASSWORD,
CONF_PORT,
CONF_SCAN_INTERVAL,
Expand All @@ -27,6 +21,11 @@
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.device_registry import format_mac

from asusrouter import AsusData
from asusrouter.error import AsusRouterAccessError
from asusrouter.modules.endpoint.error import AccessError
from asusrouter.modules.homeassistant import convert_to_ha_sensors_group

from .bridge import ARBridge
from .const import (
ACCESS_POINT,
Expand Down Expand Up @@ -120,7 +119,11 @@ def _check_errors(
if errors is None:
return False

if BASE in errors and errors[BASE] != RESULT_SUCCESS and errors[BASE] != "":
if (
BASE in errors
and errors[BASE] != RESULT_SUCCESS
and errors[BASE] != ""
):
return True

return False
Expand Down Expand Up @@ -148,7 +151,9 @@ async def _async_get_clients(
}
return result
except Exception as ex: # pylint: disable=broad-except
_LOGGER.warning("Cannot get clients for %s: %s", configs[CONF_HOST], ex)
_LOGGER.warning(
"Cannot get clients for %s: %s", configs[CONF_HOST], ex
)
return {}


Expand All @@ -172,7 +177,9 @@ async def _async_get_network_interfaces(
return labels
except Exception as ex: # pylint: disable=broad-except
_LOGGER.warning(
"Cannot get available network interfaces for %s: %s", configs[CONF_HOST], ex
"Cannot get available network interfaces for %s: %s",
configs[CONF_HOST],
ex,
)
return CONF_DEFAULT_INTERFACES

Expand Down Expand Up @@ -206,7 +213,9 @@ async def _async_check_connection(
args = ex.args
# Wrong credentials
if args[1] == AccessError.CREDENTIALS:
_LOGGER.error("Error during connection to `%s`. Wrong credentials", host)
_LOGGER.error(
"Error during connection to `%s`. Wrong credentials", host
)
return {
ERRORS: RESULT_WRONG_CREDENTIALS,
}
Expand Down Expand Up @@ -317,7 +326,9 @@ async def _async_process_step(
raise ValueError(f"Step `{step}` is not properly defined")
# If the next step is defined, move to it
if NEXT in description and description[NEXT]:
return await _async_process_step(steps, description[NEXT], redirect=True)
return await _async_process_step(
steps, description[NEXT], redirect=True
)
raise ValueError(f"Step `{step}` is not properly defined")
raise ValueError(f"Step `{step}` cannot be found")

Expand All @@ -334,7 +345,9 @@ def _create_form_find(
user_input = {}

schema = {
vol.Required(CONF_HOST, default=user_input.get(CONF_HOST, "")): cv.string,
vol.Required(
CONF_HOST, default=user_input.get(CONF_HOST, "")
): cv.string,
}

return vol.Schema(schema)
Expand All @@ -351,7 +364,8 @@ def _create_form_credentials(

schema = {
vol.Required(
CONF_USERNAME, default=user_input.get(CONF_USERNAME, CONF_DEFAULT_USERNAME)
CONF_USERNAME,
default=user_input.get(CONF_USERNAME, CONF_DEFAULT_USERNAME),
): cv.string,
vol.Required(
CONF_PASSWORD, default=user_input.get(CONF_PASSWORD, "")
Expand All @@ -377,12 +391,19 @@ def _create_form_operation(
user_input = {}

schema = {
vol.Required(CONF_MODE, default=user_input.get(CONF_MODE, mode)): vol.In(
{mode: CONF_LABELS_MODE.get(mode, mode) for mode in CONF_VALUES_MODE}
vol.Required(
CONF_MODE, default=user_input.get(CONF_MODE, mode)
): vol.In(
{
mode: CONF_LABELS_MODE.get(mode, mode)
for mode in CONF_VALUES_MODE
}
),
vol.Required(
CONF_SPLIT_INTERVALS,
default=user_input.get(CONF_SPLIT_INTERVALS, CONF_DEFAULT_SPLIT_INTERVALS),
default=user_input.get(
CONF_SPLIT_INTERVALS, CONF_DEFAULT_SPLIT_INTERVALS
),
): cv.boolean,
}

Expand All @@ -405,25 +426,37 @@ def _create_form_connected_devices(
schema = {
vol.Required(
CONF_TRACK_DEVICES,
default=user_input.get(CONF_TRACK_DEVICES, CONF_DEFAULT_TRACK_DEVICES),
default=user_input.get(
CONF_TRACK_DEVICES, CONF_DEFAULT_TRACK_DEVICES
),
): cv.boolean,
vol.Required(
CONF_CLIENT_DEVICE,
default=user_input.get(CONF_CLIENT_DEVICE, CONF_DEFAULT_CLIENT_DEVICE),
default=user_input.get(
CONF_CLIENT_DEVICE, CONF_DEFAULT_CLIENT_DEVICE
),
): cv.boolean,
vol.Required(
CONF_CLIENTS_IN_ATTR,
default=user_input.get(CONF_CLIENTS_IN_ATTR, CONF_DEFAULT_CLIENTS_IN_ATTR),
default=user_input.get(
CONF_CLIENTS_IN_ATTR, CONF_DEFAULT_CLIENTS_IN_ATTR
),
): cv.boolean,
vol.Required(
CONF_CLIENT_FILTER,
default=user_input.get(CONF_CLIENT_FILTER, CONF_DEFAULT_CLIENT_FILTER),
default=user_input.get(
CONF_CLIENT_FILTER, CONF_DEFAULT_CLIENT_FILTER
),
): vol.In(CONF_LABELS_CLIENT_FILTER),
vol.Optional(
CONF_CLIENT_FILTER_LIST,
default=default,
): cv.multi_select(
dict(sorted(user_input[ALL_CLIENTS].items(), key=lambda item: item[1]))
dict(
sorted(
user_input[ALL_CLIENTS].items(), key=lambda item: item[1]
)
)
),
vol.Required(
CONF_LATEST_CONNECTED,
Expand All @@ -433,7 +466,9 @@ def _create_form_connected_devices(
): cv.positive_int,
vol.Required(
CONF_INTERVAL_DEVICES,
default=user_input.get(CONF_INTERVAL_DEVICES, CONF_DEFAULT_SCAN_INTERVAL),
default=user_input.get(
CONF_INTERVAL_DEVICES, CONF_DEFAULT_SCAN_INTERVAL
),
): cv.positive_int,
}

Expand Down Expand Up @@ -481,7 +516,9 @@ def _create_form_intervals(
}

split = user_input.get(CONF_SPLIT_INTERVALS, CONF_DEFAULT_SPLIT_INTERVALS)
conf_scan_interval = user_input.get(CONF_SCAN_INTERVAL, CONF_DEFAULT_SCAN_INTERVAL)
conf_scan_interval = user_input.get(
CONF_SCAN_INTERVAL, CONF_DEFAULT_SCAN_INTERVAL
)

if split is False:
schema.update(
Expand All @@ -505,7 +542,8 @@ def _create_form_intervals(
vol.Required(
conf,
default=user_input.get(
conf, CONF_DEFAULT_INTERVALS.get(conf, conf_scan_interval)
conf,
CONF_DEFAULT_INTERVALS.get(conf, conf_scan_interval),
),
): cv.positive_int
for conf in CONF_INTERVALS
Expand Down Expand Up @@ -572,7 +610,9 @@ def _create_form_security(
schema = {
vol.Required(
CONF_HIDE_PASSWORDS,
default=user_input.get(CONF_HIDE_PASSWORDS, CONF_DEFAULT_HIDE_PASSWORDS),
default=user_input.get(
CONF_HIDE_PASSWORDS, CONF_DEFAULT_HIDE_PASSWORDS
),
): cv.boolean,
}

Expand Down Expand Up @@ -666,7 +706,9 @@ async def async_step_credentials(

if user_input:
# Check credentials and connection
result = await _async_check_connection(self.hass, self._configs, user_input)
result = await _async_check_connection(
self.hass, self._configs, user_input
)
# Show errors if any
if ERRORS in result:
errors[BASE] = result[ERRORS]
Expand Down Expand Up @@ -753,7 +795,9 @@ async def async_step_connected_devices(
)
return self.async_show_form(
step_id=step_id,
data_schema=_create_form_connected_devices(user_input, self._mode),
data_schema=_create_form_connected_devices(
user_input, self._mode
),
)
self._options.update(user_input)

Expand Down Expand Up @@ -859,35 +903,32 @@ async def async_step_finish(
@callback
def async_get_options_flow(
config_entry: ConfigEntry,
) -> OptionsFlow:
) -> AROptionsFlowHandler:
"""Get the options flow."""

return AROptionsFlowHandler(config_entry)
return AROptionsFlowHandler()


class AROptionsFlowHandler(OptionsFlow):
"""Options flow for AsusRouter."""

def __init__(
self,
config_entry: ConfigEntry,
) -> None:
"""Initialize options flow."""

self.config_entry = config_entry
async def async_step_init(
self,
user_input: dict[str, Any] | None = None,
) -> FlowResult:
"""Options flow."""

self._selection: dict[str, Any] = {}
self._configs: dict[str, Any] = self.config_entry.data.copy()
self._host: str = self._configs[CONF_HOST]
self._options: dict[str, Any] = self.config_entry.options.copy()
self._mode = self._options.get(CONF_MODE, CONF_DEFAULT_MODE)

async def async_step_init(
self,
user_input: dict[str, Any] | None = None,
) -> FlowResult:
"""Options flow."""

return await self.async_step_options(user_input)

async def async_step_options(
Expand Down Expand Up @@ -1051,7 +1092,9 @@ async def async_step_interfaces(
user_input[INTERFACES].append(interface)
return self.async_show_form(
step_id=step_id,
data_schema=_create_form_interfaces(user_input, default=selected),
data_schema=_create_form_interfaces(
user_input, default=selected
),
)

self._options.update(user_input)
Expand Down

0 comments on commit d692f23

Please sign in to comment.