diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a8a7c38e4..e558a1444 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -50,7 +50,8 @@ jobs: - name: Run Forge tests run: | - forge test -v --fork-url https://gateway.tenderly.co/public/sepolia --fork-block-number 5196000 + forge test -vvv --fork-url https://gateway.tenderly.co/public/sepolia --fork-block-number 5196000 + id: forge-test - name: Run solhint run: npx solhint contracts/**/*.sol diff --git a/contracts/AccessController.sol b/contracts/AccessController.sol index 0fb38ac3f..9587bb859 100644 --- a/contracts/AccessController.sol +++ b/contracts/AccessController.sol @@ -9,7 +9,8 @@ import { IPAccountChecker } from "./lib/registries/IPAccountChecker.sol"; import { IIPAccount } from "./interfaces/IIPAccount.sol"; import { AccessPermission } from "./lib/AccessPermission.sol"; import { Errors } from "./lib/Errors.sol"; -import { Governable } from "./governance/Governable.sol"; +import { GovernableUpgradeable } from "./governance/GovernableUpgradeable.sol"; +import { UUPSUpgradeable } from "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol"; /// @title AccessController /// @dev This contract is used to control access permissions for different function calls in the protocol. @@ -26,29 +27,45 @@ import { Governable } from "./governance/Governable.sol"; /// - setPermission: Sets the permission for a specific function call. /// - getPermission: Returns the permission level for a specific function call. /// - checkPermission: Checks if a specific function call is allowed. -contract AccessController is IAccessController, Governable { +contract AccessController is IAccessController, GovernableUpgradeable, UUPSUpgradeable { using IPAccountChecker for IIPAccountRegistry; - address public IP_ACCOUNT_REGISTRY; - address public MODULE_REGISTRY; - - /// @dev Tracks the permission granted to an encoded permission path, where the + /// @dev The storage struct of AccessController. + /// @param encodedPermissions tracks the permission granted to an encoded permission path, where the /// encoded permission path = keccak256(abi.encodePacked(ipAccount, signer, to, func)) - mapping(bytes32 => uint8) internal encodedPermissions; + /// @notice The address of the IP Account Registry. + /// @notice The address of the Module Registry. + /// @custom:storage-location erc7201:story-protocol.AccessController + struct AccessControllerStorage { + mapping(bytes32 => uint8) encodedPermissions; + address ipAccountRegistry; + address moduleRegistry; + } - constructor(address governance) Governable(governance) {} + // keccak256(abi.encode(uint256(keccak256("story-protocol.AccessController")) - 1)) & ~bytes32(uint256(0xff)); + bytes32 private constant AccessControllerStorageLocation = + 0xe80df7f3a04d1e1a0b61a4a820184d4b4a2f8a6a808f315dbcc7b502f40b1800; - // TODO: Change the function name to not clash with potential proxy contract `initialize`. - // TODO: Only allow calling once. - /// @dev Initialize the Access Controller with the IP Account Registry and Module Registry addresses. - /// These are separated from the constructor, because we need to deploy the AccessController first for - /// to deploy many registry and module contracts, including the IP Account Registry and Module Registry. - /// @dev Enforced to be only callable by the protocol admin in governance. - /// @param ipAccountRegistry The address of the IP Account Registry. - /// @param moduleRegistry The address of the Module Registry. - function initialize(address ipAccountRegistry, address moduleRegistry) external onlyProtocolAdmin { - IP_ACCOUNT_REGISTRY = ipAccountRegistry; - MODULE_REGISTRY = moduleRegistry; + /// Constructor + /// @custom:oz-upgrades-unsafe-allow constructor + constructor() { + _disableInitializers(); + } + + /// @notice Initializes implementation contract + /// @param governance The address of the governance contract + function initialize(address governance) external initializer { + __GovernableUpgradeable_init(governance); + } + + /// @notice Sets the addresses of the IP account registry and the module registry + /// @dev TODO: figure out how to set these addresses in the constructor to make them immutable + /// @param ipAccountRegistry address of the IP account registry + /// @param moduleRegistry address of the module registry + function setAddresses(address ipAccountRegistry, address moduleRegistry) external onlyProtocolAdmin { + AccessControllerStorage storage $ = _getAccessControllerStorage(); + $.ipAccountRegistry = ipAccountRegistry; + $.moduleRegistry = moduleRegistry; } /// @notice Sets a batch of permissions in a single transaction. @@ -115,14 +132,15 @@ contract AccessController is IAccessController, Governable { if (signer == address(0)) { revert Errors.AccessController__SignerIsZeroAddress(); } - if (!IIPAccountRegistry(IP_ACCOUNT_REGISTRY).isIpAccount(ipAccount)) { + AccessControllerStorage storage $ = _getAccessControllerStorage(); + if (!IIPAccountRegistry($.ipAccountRegistry).isIpAccount(ipAccount)) { revert Errors.AccessController__IPAccountIsNotValid(ipAccount); } // permission must be one of ABSTAIN, ALLOW, DENY if (permission > 2) { revert Errors.AccessController__PermissionIsNotValid(); } - if (!IModuleRegistry(MODULE_REGISTRY).isRegistered(msg.sender) && ipAccount != msg.sender) { + if (!IModuleRegistry($.moduleRegistry).isRegistered(msg.sender) && ipAccount != msg.sender) { revert Errors.AccessController__CallerIsNotIPAccount(); } _setPermission(ipAccount, signer, to, func, permission); @@ -144,15 +162,16 @@ contract AccessController is IAccessController, Governable { // The ipAccount is restricted to interact exclusively with registered modules. // This includes initiating calls to these modules and receiving calls from them. // Additionally, it can modify Permissions settings. + AccessControllerStorage storage $ = _getAccessControllerStorage(); if ( to != address(this) && - !IModuleRegistry(MODULE_REGISTRY).isRegistered(to) && - !IModuleRegistry(MODULE_REGISTRY).isRegistered(signer) + !IModuleRegistry($.moduleRegistry).isRegistered(to) && + !IModuleRegistry($.moduleRegistry).isRegistered(signer) ) { revert Errors.AccessController__BothCallerAndRecipientAreNotRegisteredModule(signer, to); } // Must be a valid IPAccount - if (!IIPAccountRegistry(IP_ACCOUNT_REGISTRY).isIpAccount(ipAccount)) { + if (!IIPAccountRegistry($.ipAccountRegistry).isIpAccount(ipAccount)) { revert Errors.AccessController__IPAccountIsNotValid(ipAccount); } // Owner can call all functions of all modules @@ -196,12 +215,14 @@ contract AccessController is IAccessController, Governable { /// @param func The function selector of `to` that can be called by the `signer` on behalf of the `ipAccount` /// @return permission The current permission level for the function call on `to` by the `signer` for `ipAccount` function getPermission(address ipAccount, address signer, address to, bytes4 func) public view returns (uint8) { - return encodedPermissions[_encodePermission(ipAccount, signer, to, func)]; + AccessControllerStorage storage $ = _getAccessControllerStorage(); + return $.encodedPermissions[_encodePermission(ipAccount, signer, to, func)]; } /// @dev The permission parameters will be encoded into bytes32 as key in the permissions mapping to save storage function _setPermission(address ipAccount, address signer, address to, bytes4 func, uint8 permission) internal { - encodedPermissions[_encodePermission(ipAccount, signer, to, func)] = permission; + AccessControllerStorage storage $ = _getAccessControllerStorage(); + $.encodedPermissions[_encodePermission(ipAccount, signer, to, func)] = permission; } /// @dev encode permission to hash (bytes32) @@ -216,4 +237,15 @@ contract AccessController is IAccessController, Governable { } return keccak256(abi.encode(IIPAccount(payable(ipAccount)).owner(), ipAccount, signer, to, func)); } + + /// @dev Returns the storage struct of AccessController. + function _getAccessControllerStorage() private pure returns (AccessControllerStorage storage $) { + assembly { + $.slot := AccessControllerStorageLocation + } + } + + /// @dev Hook to authorize the upgrade according to UUPSUgradeable + /// @param newImplementation The address of the new implementation + function _authorizeUpgrade(address newImplementation) internal override onlyProtocolAdmin {} } diff --git a/contracts/access/AccessControlled.sol b/contracts/access/AccessControlled.sol index 28b8af49c..cbacc02d3 100644 --- a/contracts/access/AccessControlled.sol +++ b/contracts/access/AccessControlled.sol @@ -17,20 +17,12 @@ abstract contract AccessControlled { using IPAccountChecker for IIPAccountRegistry; /// @notice The IAccessController instance for permission checks. + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable IAccessController public immutable ACCESS_CONTROLLER; /// @notice The IIPAccountRegistry instance for IP account verification. + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable IIPAccountRegistry public immutable IP_ACCOUNT_REGISTRY; - /// @dev Initializes the contract by setting the ACCESS_CONTROLLER and IP_ACCOUNT_REGISTRY addresses. - /// @param accessController The address of the AccessController contract. - /// @param ipAccountRegistry The address of the IPAccountRegistry contract. - constructor(address accessController, address ipAccountRegistry) { - if (accessController == address(0)) revert Errors.AccessControlled__ZeroAddress(); - if (ipAccountRegistry == address(0)) revert Errors.AccessControlled__ZeroAddress(); - ACCESS_CONTROLLER = IAccessController(accessController); - IP_ACCOUNT_REGISTRY = IIPAccountRegistry(ipAccountRegistry); - } - /// @notice Verifies that the caller has the necessary permission for the given IPAccount. /// @dev Modifier that calls _verifyPermission to check if the provided IP account has the required permission. /// modules can use this modifier to check if the caller has the necessary permission. @@ -51,6 +43,17 @@ abstract contract AccessControlled { _; } + /// @dev Constructor contract by setting the ACCESS_CONTROLLER and IP_ACCOUNT_REGISTRY addresses. + /// @param accessController The address of the AccessController contract. + /// @param ipAccountRegistry The address of the IPAccountRegistry contract. + /// @custom:oz-upgrades-unsafe-allow constructor + constructor(address accessController, address ipAccountRegistry) { + if (accessController == address(0)) revert Errors.AccessControlled__ZeroAddress(); + if (ipAccountRegistry == address(0)) revert Errors.AccessControlled__ZeroAddress(); + ACCESS_CONTROLLER = IAccessController(accessController); + IP_ACCOUNT_REGISTRY = IIPAccountRegistry(ipAccountRegistry); + } + /// @dev Internal function to verify if the caller (msg.sender) has the required permission to execute /// the function on provided ipAccount. /// @param ipAccount The address of the IP account to verify. diff --git a/contracts/governance/GovernableUpgradeable.sol b/contracts/governance/GovernableUpgradeable.sol index 56f214e99..b2c0fa2e1 100644 --- a/contracts/governance/GovernableUpgradeable.sol +++ b/contracts/governance/GovernableUpgradeable.sol @@ -12,8 +12,9 @@ import { GovernanceLib } from "../lib/GovernanceLib.sol"; /// @title Governable /// @dev All contracts managed by governance should inherit from this contract. abstract contract GovernableUpgradeable is IGovernable, Initializable { - /// @custom:storage-location erc7201:story-protocol.GovernableUpgradeable + /// @dev Storage for GovernableUpgradeable /// @param governance The address of the governance. + /// @custom:storage-location erc7201:story-protocol.GovernableUpgradeable struct GovernableUpgradeableStorage { address governance; } @@ -44,7 +45,7 @@ abstract contract GovernableUpgradeable is IGovernable, Initializable { _disableInitializers(); } - function __GovernableUpgradeable_init(address governance_) internal { + function __GovernableUpgradeable_init(address governance_) internal onlyInitializing { if (governance_ == address(0)) revert Errors.Governance__ZeroAddress(); _getGovernableUpgradeableStorage().governance = governance_; emit GovernanceUpdated(governance_); diff --git a/contracts/lib/Errors.sol b/contracts/lib/Errors.sol index cbadb828a..0fd4a5e38 100644 --- a/contracts/lib/Errors.sol +++ b/contracts/lib/Errors.sol @@ -171,10 +171,10 @@ library Errors { error LicensingModule__LinkingRevokedLicense(); //////////////////////////////////////////////////////////////////////////// - // LicensingModuleAware // + // BasePolicyFrameworkManager // //////////////////////////////////////////////////////////////////////////// - error LicensingModuleAware__CallerNotLicensingModule(); + error BasePolicyFrameworkManager__CallerNotLicensingModule(); //////////////////////////////////////////////////////////////////////////// // PolicyFrameworkManager // diff --git a/contracts/lib/registries/IPAccountChecker.sol b/contracts/lib/registries/IPAccountChecker.sol index c06a24ce5..b2c2cd528 100644 --- a/contracts/lib/registries/IPAccountChecker.sol +++ b/contracts/lib/registries/IPAccountChecker.sol @@ -23,7 +23,7 @@ library IPAccountChecker { uint256 chainId_, address tokenContract_, uint256 tokenId_ - ) external view returns (bool) { + ) internal view returns (bool) { return ipAccountRegistry_.ipAccount(chainId_, tokenContract_, tokenId_).code.length != 0; } @@ -34,7 +34,7 @@ library IPAccountChecker { function isIpAccount( IIPAccountRegistry ipAccountRegistry_, address ipAccountAddress_ - ) external view returns (bool) { + ) internal view returns (bool) { if (ipAccountAddress_ == address(0)) return false; if (ipAccountAddress_.code.length == 0) return false; if (!ERC165Checker.supportsERC165(ipAccountAddress_)) return false; diff --git a/contracts/modules/licensing/BasePolicyFrameworkManager.sol b/contracts/modules/licensing/BasePolicyFrameworkManager.sol index 3b72d22e4..3ae4fe3bf 100644 --- a/contracts/modules/licensing/BasePolicyFrameworkManager.sol +++ b/contracts/modules/licensing/BasePolicyFrameworkManager.sol @@ -4,26 +4,82 @@ pragma solidity 0.8.23; // external import { ERC165 } from "@openzeppelin/contracts/utils/introspection/ERC165.sol"; import { IERC165 } from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; +import { Initializable } from "@openzeppelin/contracts-upgradeable/proxy/utils/Initializable.sol"; // contracts import { IPolicyFrameworkManager } from "../../interfaces/modules/licensing/IPolicyFrameworkManager.sol"; -import { LicensingModuleAware } from "../../modules/licensing/LicensingModuleAware.sol"; +import { ILicensingModule } from "../../interfaces/modules/licensing/ILicensingModule.sol"; +import { Errors } from "../../lib/Errors.sol"; /// @title BasePolicyFrameworkManager +/// TODO: If we want to open this, we need an upgradeable and non-upgradeable Base version, or just promote +/// the IPolicyFrameworkManager in the docs. /// @notice Base contract for policy framework managers. -abstract contract BasePolicyFrameworkManager is IPolicyFrameworkManager, ERC165, LicensingModuleAware { - /// @notice Returns the name to be show in license NFT (LNFT) metadata - string public override name; +abstract contract BasePolicyFrameworkManager is IPolicyFrameworkManager, ERC165, Initializable { + /// @dev Storage for BasePolicyFrameworkManager + /// @param name The name of the policy framework manager + /// @param licenseTextUrl The URL to the off chain legal agreement template text + /// @custom:storage-location erc7201:story-protocol.BasePolicyFrameworkManager + struct BasePolicyFrameworkManagerStorage { + string name; + string licenseTextUrl; + } - /// @notice Returns the URL to the off chain legal agreement template text - string public override licenseTextUrl; + // keccak256(abi.encode(uint256(keccak256("story-protocol.BasePolicyFrameworkManager")) - 1)) + // & ~bytes32(uint256(0xff)); + bytes32 private constant BasePolicyFrameworkManagerStorageLocation = + 0xa55803740ac9329334ad7b6cde0ec056cc3ba32125b59c579552512bed001f00; + + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable + ILicensingModule public immutable LICENSING_MODULE; + + /// @notice Modifier for authorizing the calling entity to only the LicensingModule. + modifier onlyLicensingModule() { + if (msg.sender != address(LICENSING_MODULE)) { + revert Errors.BasePolicyFrameworkManager__CallerNotLicensingModule(); + } + _; + } + + /// @notice Constructor function + /// @param licensing The address of the LicensingModule + /// @custom:oz-upgrades-unsafe-allow constructor + constructor(address licensing) { + LICENSING_MODULE = ILicensingModule(licensing); + } - constructor(address licensing, string memory name_, string memory licenseTextUrl_) LicensingModuleAware(licensing) { - name = name_; - licenseTextUrl = licenseTextUrl_; + /// @notice Initializes the BasePolicyFrameworkManager contract as per the Initializable contract. + /// @param _name The name of the policy framework manager + /// @param _licenseTextUrl The URL to the off chain legal agreement template text + function __BasePolicyFrameworkManager_init( + string memory _name, + string memory _licenseTextUrl + ) internal onlyInitializing { + _getBasePolicyFrameworkManagerStorage().name = _name; + _getBasePolicyFrameworkManagerStorage().licenseTextUrl = _licenseTextUrl; + } + + /// @notice Returns the name of the policy framework manager + function name() public view override returns (string memory) { + return _getBasePolicyFrameworkManagerStorage().name; + } + + /// @notice Returns the URL to the off chain legal agreement template text + function licenseTextUrl() public view override returns (string memory) { + return _getBasePolicyFrameworkManagerStorage().licenseTextUrl; } /// @notice IERC165 interface support. function supportsInterface(bytes4 interfaceId) public view virtual override(ERC165, IERC165) returns (bool) { return interfaceId == type(IPolicyFrameworkManager).interfaceId || super.supportsInterface(interfaceId); } + + function _getBasePolicyFrameworkManagerStorage() + internal + pure + returns (BasePolicyFrameworkManagerStorage storage $) + { + assembly { + $.slot := BasePolicyFrameworkManagerStorageLocation + } + } } diff --git a/contracts/modules/licensing/LicensingModule.sol b/contracts/modules/licensing/LicensingModule.sol index 75574f30b..191f20729 100644 --- a/contracts/modules/licensing/LicensingModule.sol +++ b/contracts/modules/licensing/LicensingModule.sol @@ -5,7 +5,8 @@ pragma solidity 0.8.23; import { EnumerableSet } from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import { Strings } from "@openzeppelin/contracts/utils/Strings.sol"; import { ERC165Checker } from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol"; -import { ReentrancyGuard } from "@openzeppelin/contracts/utils/ReentrancyGuard.sol"; +import { ReentrancyGuardUpgradeable } from "@openzeppelin/contracts-upgradeable/utils/ReentrancyGuardUpgradeable.sol"; +import { UUPSUpgradeable } from "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol"; import { IIPAccount } from "../../interfaces/IIPAccount.sol"; import { IPolicyFrameworkManager } from "../../interfaces/modules/licensing/IPolicyFrameworkManager.sol"; @@ -22,6 +23,7 @@ import { RoyaltyModule } from "../../modules/royalty/RoyaltyModule.sol"; import { AccessControlled } from "../../access/AccessControlled.sol"; import { LICENSING_MODULE_KEY } from "../../lib/modules/Module.sol"; import { BaseModule } from "../BaseModule.sol"; +import { GovernableUpgradeable } from "../../governance/GovernableUpgradeable.sol"; // TODO: consider disabling operators/approvals on creation /// @title Licensing Module @@ -32,7 +34,14 @@ import { BaseModule } from "../BaseModule.sol"; /// - Linking IP to its parent /// - Verifying linking parameters /// - Verifying policy parameters -contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, ReentrancyGuard { +contract LicensingModule is + AccessControlled, + ILicensingModule, + BaseModule, + ReentrancyGuardUpgradeable, + GovernableUpgradeable, + UUPSUpgradeable +{ using ERC165Checker for address; using IPAccountChecker for IIPAccountRegistry; using EnumerableSet for EnumerableSet.UintSet; @@ -40,42 +49,47 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen using Licensing for *; using Strings for *; + /// @notice Storage struct for the LicensingModule + /// @param registeredFrameworkManagers Mapping of registered policy framework managers + /// @param hashedPolicies Mapping of policy data to policy id + /// @param policies Mapping of policy id to policy data (hashed) + /// @param totalPolicies Total amount of distinct licensing policies in LicenseRegistry + /// @param policySetups Internal mapping to track if a policy was set by linking or minting, + /// and the index of the policy in the + /// ipId policy set. Policies can't be removed, but they can be deactivated by setting active to false. + /// @param policiesPerIpId the set of policy ids attached to the given ipId + /// @param ipIdParents Mapping of parent policy ids for the given ipId + /// @param ipRights Mapping of policy aggregator data for the given ipId in a framework + /// @custom:storage-location erc7201:story-protocol.LicensingModule + struct LicensingModuleStorage { + mapping(address framework => bool registered) registeredFrameworkManagers; + mapping(bytes32 policyHash => uint256 policyId) hashedPolicies; + mapping(uint256 policyId => Licensing.Policy policyData) policies; + uint256 totalPolicies; + mapping(address ipId => mapping(uint256 policyId => PolicySetup setup)) policySetups; + mapping(bytes32 hashIpIdAnInherited => EnumerableSet.UintSet policyIds) policiesPerIpId; + mapping(address ipId => EnumerableSet.AddressSet parentIpIds) ipIdParents; + mapping(address framework => mapping(address ipId => bytes policyAggregatorData)) ipRights; + } + /// @inheritdoc IModule string public constant override name = LICENSING_MODULE_KEY; /// @notice Returns the canonical protocol-wide RoyaltyModule + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable RoyaltyModule public immutable ROYALTY_MODULE; /// @notice Returns the canonical protocol-wide LicenseRegistry + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable ILicenseRegistry public immutable LICENSE_REGISTRY; /// @notice Returns the dispute module + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable IDisputeModule public immutable DISPUTE_MODULE; - /// @dev Returns if a framework is registered or not - mapping(address framework => bool registered) private _registeredFrameworkManagers; - - /// @dev Returns the policy id for the given policy data (hashed) - mapping(bytes32 policyHash => uint256 policyId) private _hashedPolicies; - - /// @dev Returns the policy data for the given policy id - mapping(uint256 policyId => Licensing.Policy policyData) private _policies; - - /// @dev Total amount of distinct licensing policies in LicenseRegistry - uint256 private _totalPolicies; - - /// @dev Internal mapping to track if a policy was set by linking or minting, and the index of the policy in the - /// ipId policy set. Policies can't be removed, but they can be deactivated by setting active to false. - mapping(address ipId => mapping(uint256 policyId => PolicySetup setup)) private _policySetups; - - /// @dev Returns the set of policy ids attached to the given ipId - mapping(bytes32 hashIpIdAnInherited => EnumerableSet.UintSet policyIds) private _policiesPerIpId; - - /// @dev Returns the set of parent policy ids for the given ipId - mapping(address ipId => EnumerableSet.AddressSet parentIpIds) private _ipIdParents; - - /// @dev Returns the policy aggregator data for the given ipId in a framework - mapping(address framework => mapping(address ipId => bytes policyAggregatorData)) private _ipRights; + // keccak256(abi.encode(uint256(keccak256("story-protocol.LicensingModule")) - 1)) & ~bytes32(uint256(0xff)); + bytes32 private constant LicensingModuleStorageLocation = + 0x0f7178cb62e4803c52d40f70c08a6f88d6ee1af1838d58e0c83a222a6c3d3100; /// @notice Modifier to allow only LicenseRegistry as the caller modifier onlyLicenseRegistry() { @@ -83,6 +97,13 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen _; } + /// Constructor + /// @param accessController The address of the AccessController contract + /// @param ipAccountRegistry The address of the IPAccountRegistry contract + /// @param royaltyModule The address of the RoyaltyModule contract + /// @param registry The address of the LicenseRegistry contract + /// @param disputeModule The address of the DisputeModule contract + /// @custom:oz-upgrades-unsafe-allow constructor constructor( address accessController, address ipAccountRegistry, @@ -93,6 +114,13 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen ROYALTY_MODULE = RoyaltyModule(royaltyModule); LICENSE_REGISTRY = ILicenseRegistry(registry); DISPUTE_MODULE = IDisputeModule(disputeModule); + _disableInitializers(); + } + + function initialize(address governance) public initializer { + __ReentrancyGuard_init(); + __UUPSUpgradeable_init(); + __GovernableUpgradeable_init(governance); } /// @notice Registers a policy framework manager into the contract, so it can add policy data for licenses. @@ -106,7 +134,8 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen if (bytes(licenseUrl).length == 0 || licenseUrl.equal("")) { revert Errors.LicensingModule__EmptyLicenseUrl(); } - _registeredFrameworkManagers[manager] = true; + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + $.registeredFrameworkManagers[manager] = true; emit PolicyFrameworkRegistered(manager, fwManager.name(), licenseUrl); } @@ -126,14 +155,15 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen if (pol.mintingFee > 0 && !ROYALTY_MODULE.isWhitelistedRoyaltyToken(pol.mintingFeeToken)) { revert Errors.LicensingModule__MintingFeeTokenNotWhitelisted(); } + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); (uint256 polId, bool newPol) = DataUniqueness.addIdOrGetExisting( abi.encode(pol), - _hashedPolicies, - _totalPolicies + $.hashedPolicies, + $.totalPolicies ); if (newPol) { - _totalPolicies = polId; - _policies[polId] = pol; + $.totalPolicies = polId; + $.policies[polId] = pol; emit PolicyRegistered( polId, pol.policyFramework, @@ -183,7 +213,8 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen address receiver, bytes calldata royaltyContext ) external nonReentrant returns (uint256 licenseId) { - _verifyPolicy(_policies[policyId]); + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + _verifyPolicy($.policies[policyId]); if (!IP_ACCOUNT_REGISTRY.isIpAccount(licensorIpId)) { revert Errors.LicensingModule__LicensorNotRegistered(); } @@ -195,7 +226,7 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen } _verifyIpNotDisputed(licensorIpId); - bool isInherited = _policySetups[licensorIpId][policyId].isInherited; + bool isInherited = $.policySetups[licensorIpId][policyId].isInherited; Licensing.Policy memory pol = policy(policyId); IPolicyFrameworkManager pfm = IPolicyFrameworkManager(pol.policyFramework); @@ -314,20 +345,23 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen /// @param policyFramework The address of the policy framework manager /// @return isRegistered True if the framework is registered function isFrameworkRegistered(address policyFramework) external view returns (bool) { - return _registeredFrameworkManagers[policyFramework]; + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + return $.registeredFrameworkManagers[policyFramework]; } /// @notice Returns amount of distinct licensing policies in the LicensingModule. /// @return totalPolicies The amount of policies function totalPolicies() external view returns (uint256) { - return _totalPolicies; + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + return $.totalPolicies; } /// @notice Returns the policy data for policyId, reverts if not found. /// @param policyId The id of the policy /// @return pol The policy data function policy(uint256 policyId) public view returns (Licensing.Policy memory pol) { - pol = _policies[policyId]; + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + pol = $.policies[policyId]; _verifyPolicy(pol); return pol; } @@ -336,7 +370,8 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen /// @param pol The policy data in Policy struct /// @return policyId The id of the policy function getPolicyId(Licensing.Policy calldata pol) external view returns (uint256 policyId) { - return _hashedPolicies[keccak256(abi.encode(pol))]; + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + return $.hashedPolicies[keccak256(abi.encode(pol))]; } /// @notice Returns the policy aggregator data for the given IP ID in the framework. @@ -344,14 +379,16 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen /// @param ipId The id of the IP asset /// @return data The encoded policy aggregator data to be decoded by the framework manager function policyAggregatorData(address framework, address ipId) external view returns (bytes memory) { - return _ipRights[framework][ipId]; + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + return $.ipRights[framework][ipId]; } /// @notice Returns if policyId exists in the LicensingModule /// @param policyId The id of the policy /// @return isDefined True if the policy is defined function isPolicyDefined(uint256 policyId) public view returns (bool) { - return _policies[policyId].policyFramework != address(0); + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + return $.policies[policyId].policyFramework != address(0); } /// @notice Returns the policy ids attached to an IP @@ -402,7 +439,8 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen address ipId, uint256 index ) external view returns (Licensing.Policy memory) { - return _policies[_policySetPerIpId(isInherited, ipId).at(index)]; + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + return $.policies[_policySetPerIpId(isInherited, ipId).at(index)]; } /// @notice Returns the status of a policy in an IP's policy set @@ -415,7 +453,8 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen address ipId, uint256 policyId ) external view returns (uint256 index, bool isInherited, bool active) { - PolicySetup storage setup = _policySetups[ipId][policyId]; + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + PolicySetup storage setup = $.policySetups[ipId][policyId]; return (setup.index, setup.isInherited, setup.active); } @@ -424,7 +463,8 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen /// @param policyId The id of the policy to check if inherited /// @return isInherited True if the policy is inherited from a parent IP function isPolicyInherited(address ipId, uint256 policyId) external view returns (bool) { - return _policySetups[ipId][policyId].isInherited; + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + return $.policySetups[ipId][policyId].isInherited; } /// @notice Returns if an IP is a derivative of another IP @@ -432,26 +472,30 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen /// @param childIpId The id of the child IP asset to check /// @return isParent True if the child IP is a derivative of the parent IP function isParent(address parentIpId, address childIpId) external view returns (bool) { - return _ipIdParents[childIpId].contains(parentIpId); + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + return $.ipIdParents[childIpId].contains(parentIpId); } /// @notice Returns the list of parent IP assets for a given child IP asset /// @param ipId The id of the child IP asset to check /// @return parentIpIds The ids of the parent IP assets function parentIpIds(address ipId) external view returns (address[] memory) { - return _ipIdParents[ipId].values(); + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + return $.ipIdParents[ipId].values(); } /// @notice Returns the total number of parents for an IP asset /// @param ipId The id of the IP asset to check /// @return totalParents The total number of parent IP assets function totalParentsForIpId(address ipId) external view returns (uint256) { - return _ipIdParents[ipId].length(); + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + return $.ipIdParents[ipId].length(); } /// @dev Verifies that the framework is registered in the LicensingModule function _verifyRegisteredFramework(address policyFramework) private view { - if (!_registeredFrameworkManagers[policyFramework]) { + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + if (!$.registeredFrameworkManagers[policyFramework]) { revert Errors.LicensingModule__FrameworkNotFound(); } } @@ -464,16 +508,17 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen bool skipIfDuplicate ) private returns (uint256 index) { _verifyCanAddPolicy(policyId, ipId, isInherited); + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); // Try and add the policy into the set. EnumerableSet.UintSet storage _pols = _policySetPerIpId(isInherited, ipId); if (!_pols.add(policyId)) { if (skipIfDuplicate) { - return _policySetups[ipId][policyId].index; + return $.policySetups[ipId][policyId].index; } revert Errors.LicensingModule__PolicyAlreadySetForIpId(); } index = _pols.length() - 1; - PolicySetup storage setup = _policySetups[ipId][policyId]; + PolicySetup storage setup = $.policySetups[ipId][policyId]; // This should not happen, but just in case if (setup.isSet) { revert Errors.LicensingModule__PolicyAlreadySetForIpId(); @@ -516,9 +561,10 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen // Add the policy of licenseIds[i] to the child. If the policy's already set from previous parents, // then the addition will be skipped. _addPolicyIdToIp({ ipId: childIpId, policyId: policyId, isInherited: true, skipIfDuplicate: true }); + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); // Set parent. We ignore the return value, since there are some cases where the same licensor gives the child // a License with another policy. - _ipIdParents[childIpId].add(licensor); + $.ipIdParents[childIpId].add(licensor); } /// @dev Verifies if the policyId can be added to the IP @@ -535,17 +581,18 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen // Owner of derivative is trying to set policies revert Errors.LicensingModule__DerivativesCannotAddPolicy(); } + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); // If we are here, this is a multiparent derivative // Checking for policy compatibility IPolicyFrameworkManager polManager = IPolicyFrameworkManager(policy(policyId).policyFramework); - Licensing.Policy memory pol = _policies[policyId]; + Licensing.Policy memory pol = $.policies[policyId]; (bool aggregatorChanged, bytes memory newAggregator) = polManager.processInheritedPolicies( - _ipRights[pol.policyFramework][ipId], + $.ipRights[pol.policyFramework][ipId], policyId, pol.frameworkData ); if (aggregatorChanged) { - _ipRights[pol.policyFramework][ipId] = newAggregator; + $.ipRights[pol.policyFramework][ipId] = newAggregator; } } @@ -558,7 +605,8 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen /// @dev Returns the policy set for the given ipId function _policySetPerIpId(bool isInherited, address ipId) private view returns (EnumerableSet.UintSet storage) { - return _policiesPerIpId[keccak256(abi.encode(isInherited, ipId))]; + LicensingModuleStorage storage $ = _getLicensingModuleStorage(); + return $.policiesPerIpId[keccak256(abi.encode(isInherited, ipId))]; } /// @dev Verifies if the IP is disputed @@ -568,4 +616,15 @@ contract LicensingModule is AccessControlled, ILicensingModule, BaseModule, Reen revert Errors.LicensingModule__DisputedIpId(); } } + + /// @dev Returns the storage struct of LicensingModule. + function _getLicensingModuleStorage() private pure returns (LicensingModuleStorage storage $) { + assembly { + $.slot := LicensingModuleStorageLocation + } + } + + /// @dev Hook to authorize the upgrade according to UUPSUgradeable + /// @param newImplementation The address of the new implementation + function _authorizeUpgrade(address newImplementation) internal override onlyProtocolAdmin {} } diff --git a/contracts/modules/licensing/LicensingModuleAware.sol b/contracts/modules/licensing/LicensingModuleAware.sol deleted file mode 100644 index 68e55fe3f..000000000 --- a/contracts/modules/licensing/LicensingModuleAware.sol +++ /dev/null @@ -1,25 +0,0 @@ -// SPDX-License-Identifier: BUSL-1.1 -pragma solidity 0.8.23; - -// contracts -import { ILicensingModule } from "../../interfaces/modules/licensing/ILicensingModule.sol"; -import { Errors } from "../../lib/Errors.sol"; - -/// @title LicensingModuleAware -/// @notice Base contract to be inherited by modules that need to access the licensing module. -abstract contract LicensingModuleAware { - /// @notice Returns the protocol-wide licensing module. - ILicensingModule public immutable LICENSING_MODULE; - - constructor(address licensingModule) { - LICENSING_MODULE = ILicensingModule(licensingModule); - } - - /// @notice Modifier for authorizing the calling entity to only the LicensingModule. - modifier onlyLicensingModule() { - if (msg.sender != address(LICENSING_MODULE)) { - revert Errors.LicensingModuleAware__CallerNotLicensingModule(); - } - _; - } -} diff --git a/contracts/modules/licensing/PILPolicyFrameworkManager.sol b/contracts/modules/licensing/PILPolicyFrameworkManager.sol index 5a9dda6f3..e6deb82a4 100644 --- a/contracts/modules/licensing/PILPolicyFrameworkManager.sol +++ b/contracts/modules/licensing/PILPolicyFrameworkManager.sol @@ -5,7 +5,7 @@ pragma solidity 0.8.23; // external import { Strings } from "@openzeppelin/contracts/utils/Strings.sol"; import { ERC165Checker } from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol"; -import { ReentrancyGuard } from "@openzeppelin/contracts/utils/ReentrancyGuard.sol"; +import { ReentrancyGuardUpgradeable } from "@openzeppelin/contracts-upgradeable/utils/ReentrancyGuardUpgradeable.sol"; // contracts import { IHookModule } from "../../interfaces/modules/base/IHookModule.sol"; @@ -25,7 +25,7 @@ contract PILPolicyFrameworkManager is IPILPolicyFrameworkManager, BasePolicyFrameworkManager, LicensorApprovalChecker, - ReentrancyGuard + ReentrancyGuardUpgradeable { using ERC165Checker for address; using Strings for *; @@ -34,20 +34,30 @@ contract PILPolicyFrameworkManager is bytes32 private constant _EMPTY_STRING_ARRAY_HASH = 0x569e75fc77c1a856f6daaf9e69d8a9566ca34aa47f9133711ce065a571af0cfd; + /// Constructor + /// @param accessController the address of the AccessController + /// @param ipAccountRegistry the address of the IPAccountRegistry + /// @param licensing the address of the LicensingModule + /// @custom:oz-upgrades-unsafe-allow constructor constructor( address accessController, address ipAccountRegistry, - address licensing, - string memory name_, - string memory licenseUrl_ + address licensing ) - BasePolicyFrameworkManager(licensing, name_, licenseUrl_) + BasePolicyFrameworkManager(licensing) LicensorApprovalChecker( accessController, ipAccountRegistry, address(ILicensingModule(licensing).LICENSE_REGISTRY()) ) - {} + { + _disableInitializers(); + } + + function initialize(string memory name, string memory licenseTextUrl) external initializer { + __BasePolicyFrameworkManager_init(name, licenseTextUrl); + __ReentrancyGuard_init(); + } /// @notice Registers a new policy to the registry /// @dev Internally, this function must generate a Licensing.Policy struct and call registerPolicy. diff --git a/contracts/modules/licensing/parameter-helpers/LicensorApprovalChecker.sol b/contracts/modules/licensing/parameter-helpers/LicensorApprovalChecker.sol index fdda0920f..3bdffba14 100644 --- a/contracts/modules/licensing/parameter-helpers/LicensorApprovalChecker.sol +++ b/contracts/modules/licensing/parameter-helpers/LicensorApprovalChecker.sol @@ -4,10 +4,12 @@ pragma solidity 0.8.23; import { AccessControlled } from "../../../access/AccessControlled.sol"; import { ILicenseRegistry } from "../../../interfaces/registries/ILicenseRegistry.sol"; +import { Initializable } from "@openzeppelin/contracts-upgradeable/proxy/utils/Initializable.sol"; + /// @title LicensorApprovalChecker /// @notice Manages the approval of derivative IP accounts by the licensor. Used to verify /// licensing terms like "Derivatives With Approval" in PIL. -abstract contract LicensorApprovalChecker is AccessControlled { +abstract contract LicensorApprovalChecker is AccessControlled, Initializable { /// @notice Emits when a derivative IP account is approved by the licensor. /// @param licenseId The ID of the license waiting for approval /// @param ipId The ID of the derivative IP to be approved @@ -15,13 +17,28 @@ abstract contract LicensorApprovalChecker is AccessControlled { /// @param approved Result of the approval event DerivativeApproved(uint256 indexed licenseId, address indexed ipId, address indexed caller, bool approved); + /// @notice Storage for derivative IP approvals. + /// @param approvals Approvals for derivative IP. + /// @dev License Id => licensor => childIpId => approved + /// @custom:storage-location erc7201:story-protocol.LicensorApprovalChecker + struct LicensorApprovalCheckerStorage { + mapping(uint256 => mapping(address => mapping(address => bool))) approvals; + } + /// @notice Returns the license registry address + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable ILicenseRegistry public immutable LICENSE_REGISTRY; - /// @notice Approvals for derivative IP. - /// @dev License Id => licensor => childIpId => approved - mapping(uint256 => mapping(address => mapping(address => bool))) private _approvals; + // keccak256(abi.encode(uint256(keccak256("story-protocol.LicensorApprovalChecker")) - 1)) + // & ~bytes32(uint256(0xff)); + bytes32 private constant LicensorApprovalCheckerStorageLocation = + 0x7a71306cccadc52d66a0a466930bd537acf0ba900f21654919d58cece4cf9500; + /// @notice Constructor function + /// @param accessController The address of the AccessController contract + /// @param ipAccountRegistry The address of the IPAccountRegistry contract + /// @param licenseRegistry The address of the LicenseRegistry contract + /// @custom:oz-upgrades-unsafe-allow constructor constructor( address accessController, address ipAccountRegistry, @@ -45,7 +62,8 @@ abstract contract LicensorApprovalChecker is AccessControlled { /// @return approved True if the derivative IP account using the license is approved function isDerivativeApproved(uint256 licenseId, address childIpId) public view returns (bool) { address licensorIpId = LICENSE_REGISTRY.licensorIpId(licenseId); - return _approvals[licenseId][licensorIpId][childIpId]; + LicensorApprovalCheckerStorage storage $ = _getLicensorApprovalCheckerStorage(); + return $.approvals[licenseId][licensorIpId][childIpId]; } /// @notice Sets the approval for a derivative IP account. @@ -60,7 +78,15 @@ abstract contract LicensorApprovalChecker is AccessControlled { address childIpId, bool approved ) internal verifyPermission(licensorIpId) { - _approvals[licenseId][licensorIpId][childIpId] = approved; + LicensorApprovalCheckerStorage storage $ = _getLicensorApprovalCheckerStorage(); + $.approvals[licenseId][licensorIpId][childIpId] = approved; emit DerivativeApproved(licenseId, licensorIpId, msg.sender, approved); } + + /// @dev Returns the storage struct of LicensorApprovalChecker. + function _getLicensorApprovalCheckerStorage() private pure returns (LicensorApprovalCheckerStorage storage $) { + assembly { + $.slot := LicensorApprovalCheckerStorageLocation + } + } } diff --git a/contracts/registries/LicenseRegistry.sol b/contracts/registries/LicenseRegistry.sol index 140336bf1..68fc55ac3 100644 --- a/contracts/registries/LicenseRegistry.sol +++ b/contracts/registries/LicenseRegistry.sol @@ -317,7 +317,6 @@ contract LicenseRegistry is ILicenseRegistry, ERC1155Upgradeable, GovernableUpgr } /// @dev Hook to authorize the upgrade according to UUPSUgradeable - /// Must be called by ProtocolRoles.UPGRADER /// @param newImplementation The address of the new implementation function _authorizeUpgrade(address newImplementation) internal override onlyProtocolAdmin {} } diff --git a/contracts/registries/ModuleRegistry.sol b/contracts/registries/ModuleRegistry.sol index a0eae7aaa..5d670d485 100644 --- a/contracts/registries/ModuleRegistry.sol +++ b/contracts/registries/ModuleRegistry.sol @@ -3,31 +3,49 @@ pragma solidity 0.8.23; import { ERC165Checker } from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol"; import { Strings } from "@openzeppelin/contracts/utils/Strings.sol"; +import { UUPSUpgradeable } from "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol"; import { IModuleRegistry } from "../interfaces/registries/IModuleRegistry.sol"; import { Errors } from "../lib/Errors.sol"; import { IModule } from "../interfaces/modules/base/IModule.sol"; -import { Governable } from "../governance/Governable.sol"; +import { GovernableUpgradeable } from "../governance/GovernableUpgradeable.sol"; + import { MODULE_TYPE_DEFAULT } from "../lib/modules/Module.sol"; /// @title ModuleRegistry /// @notice This contract is used to register and track modules in the protocol. -contract ModuleRegistry is IModuleRegistry, Governable { +contract ModuleRegistry is IModuleRegistry, GovernableUpgradeable, UUPSUpgradeable { using Strings for *; using ERC165Checker for address; - /// @dev Returns the address of a registered module by its name. - mapping(string moduleName => address moduleAddress) internal modules; + /// @dev Storage for the ModuleRegistry. + /// @param modules The address of a registered module by its name. + /// @param moduleTypes The module type of a registered module by its address. + /// @param allModuleTypes The interface ID of a registered module type. + /// @custom:storage-location erc7201:story-protocol.ModuleRegistry + struct ModuleRegistryStorage { + mapping(string moduleName => address moduleAddress) modules; + mapping(address moduleAddress => string moduleType) moduleTypes; + mapping(string moduleType => bytes4 moduleTypeInterface) allModuleTypes; + } - /// @dev Returns the module type of a registered module by its address. - mapping(address moduleAddress => string moduleType) internal moduleTypes; + // keccak256(abi.encode(uint256(keccak256("story-protocol.ModuleRegistry")) - 1)) & ~bytes32(uint256(0xff)); + bytes32 private constant ModuleRegistryStorageLocation = + 0xa17d78ae7aee011aefa3f1388acb36741284b44eb3fcffe23ecc3a736eaa2700; - /// @dev Returns the interface ID of a registered module type. - mapping(string moduleType => bytes4 moduleTypeInterface) internal allModuleTypes; + /// @custom:oz-upgrades-unsafe-allow constructor + constructor() { + _disableInitializers(); + } + + /// @notice Initializes the ModuleRegistry contract as per the Initializable contract. + /// @param governance_ The address of the governance. + function initialize(address governance_) public initializer { + __GovernableUpgradeable_init(governance_); + __UUPSUpgradeable_init(); - constructor(address governance) Governable(governance) { // Register the default module types - allModuleTypes[MODULE_TYPE_DEFAULT] = type(IModule).interfaceId; + _getModuleRegistryStorage().allModuleTypes[MODULE_TYPE_DEFAULT] = type(IModule).interfaceId; } /// @notice Registers a new module type in the registry associate with an interface. @@ -35,16 +53,17 @@ contract ModuleRegistry is IModuleRegistry, Governable { /// @param name The name of the module type to be registered. /// @param interfaceId The interface ID associated with the module type. function registerModuleType(string memory name, bytes4 interfaceId) external override onlyProtocolAdmin { + ModuleRegistryStorage storage $ = _getModuleRegistryStorage(); if (interfaceId == 0) { revert Errors.ModuleRegistry__InterfaceIdZero(); } if (bytes(name).length == 0) { revert Errors.ModuleRegistry__NameEmptyString(); } - if (allModuleTypes[name] != 0) { + if ($.allModuleTypes[name] != 0) { revert Errors.ModuleRegistry__ModuleTypeAlreadyRegistered(); } - allModuleTypes[name] = interfaceId; + $.allModuleTypes[name] = interfaceId; } /// @notice Removes a module type from the registry. @@ -54,10 +73,11 @@ contract ModuleRegistry is IModuleRegistry, Governable { if (bytes(name).length == 0) { revert Errors.ModuleRegistry__NameEmptyString(); } - if (allModuleTypes[name] == 0) { + ModuleRegistryStorage storage $ = _getModuleRegistryStorage(); + if ($.allModuleTypes[name] == 0) { revert Errors.ModuleRegistry__ModuleTypeNotRegistered(); } - delete allModuleTypes[name]; + delete $.allModuleTypes[name]; } /// @notice Registers a new module in the registry. @@ -87,14 +107,14 @@ contract ModuleRegistry is IModuleRegistry, Governable { if (bytes(name).length == 0) { revert Errors.ModuleRegistry__NameEmptyString(); } - - if (modules[name] == address(0)) { + ModuleRegistryStorage storage $ = _getModuleRegistryStorage(); + if ($.modules[name] == address(0)) { revert Errors.ModuleRegistry__ModuleNotRegistered(); } - address module = modules[name]; - delete modules[name]; - delete moduleTypes[module]; + address module = $.modules[name]; + delete $.modules[name]; + delete $.moduleTypes[module]; emit ModuleRemoved(name, module); } @@ -103,33 +123,35 @@ contract ModuleRegistry is IModuleRegistry, Governable { /// @param moduleAddress The address of the module. /// @return isRegistered True if the module is registered, false otherwise. function isRegistered(address moduleAddress) external view returns (bool) { - return bytes(moduleTypes[moduleAddress]).length > 0; + ModuleRegistryStorage storage $ = _getModuleRegistryStorage(); + return bytes($.moduleTypes[moduleAddress]).length > 0; } /// @notice Returns the address of a module. /// @param name The name of the module. /// @return The address of the module. function getModule(string memory name) external view returns (address) { - return modules[name]; + return _getModuleRegistryStorage().modules[name]; } /// @notice Returns the module type of a given module address. /// @param moduleAddress The address of the module. /// @return The type of the module as a string. function getModuleType(address moduleAddress) external view returns (string memory) { - return moduleTypes[moduleAddress]; + return _getModuleRegistryStorage().moduleTypes[moduleAddress]; } /// @notice Returns the interface ID associated with a given module type. /// @param moduleType The type of the module as a string. /// @return The interface ID of the module type as bytes4. function getModuleTypeInterfaceId(string memory moduleType) external view returns (bytes4) { - return allModuleTypes[moduleType]; + return _getModuleRegistryStorage().allModuleTypes[moduleType]; } /// @dev Registers a new module in the registry. // solhint-disable code-complexity function _registerModule(string memory name, address moduleAddress, string memory moduleType) internal { + ModuleRegistryStorage storage $ = _getModuleRegistryStorage(); if (moduleAddress == address(0)) { revert Errors.ModuleRegistry__ModuleAddressZeroAddress(); } @@ -139,28 +161,39 @@ contract ModuleRegistry is IModuleRegistry, Governable { if (moduleAddress.code.length == 0) { revert Errors.ModuleRegistry__ModuleAddressNotContract(); } - if (bytes(moduleTypes[moduleAddress]).length > 0) { + if (bytes($.moduleTypes[moduleAddress]).length > 0) { revert Errors.ModuleRegistry__ModuleAlreadyRegistered(); } if (bytes(name).length == 0) { revert Errors.ModuleRegistry__NameEmptyString(); } - if (modules[name] != address(0)) { + if ($.modules[name] != address(0)) { revert Errors.ModuleRegistry__NameAlreadyRegistered(); } if (!IModule(moduleAddress).name().equal(name)) { revert Errors.ModuleRegistry__NameDoesNotMatch(); } - bytes4 moduleTypeInterfaceId = allModuleTypes[moduleType]; + bytes4 moduleTypeInterfaceId = $.allModuleTypes[moduleType]; if (moduleTypeInterfaceId == 0) { revert Errors.ModuleRegistry__ModuleTypeNotRegistered(); } if (!moduleAddress.supportsInterface(moduleTypeInterfaceId)) { revert Errors.ModuleRegistry__ModuleNotSupportExpectedModuleTypeInterfaceId(); } - modules[name] = moduleAddress; - moduleTypes[moduleAddress] = moduleType; + $.modules[name] = moduleAddress; + $.moduleTypes[moduleAddress] = moduleType; emit ModuleAdded(name, moduleAddress, moduleTypeInterfaceId, moduleType); } + + /// @dev Returns the storage struct of the ModuleRegistry. + function _getModuleRegistryStorage() private pure returns (ModuleRegistryStorage storage $) { + assembly { + $.slot := ModuleRegistryStorageLocation + } + } + + /// @dev Hook to authorize the upgrade according to UUPSUgradeable + /// @param newImplementation The address of the new implementation + function _authorizeUpgrade(address newImplementation) internal override onlyProtocolAdmin {} } diff --git a/script/foundry/deployment/Main.s.sol b/script/foundry/deployment/Main.s.sol index 8eb180cdd..3c77f2d7e 100644 --- a/script/foundry/deployment/Main.s.sol +++ b/script/foundry/deployment/Main.s.sol @@ -115,7 +115,12 @@ contract Main is Script, BroadcastManager, JsonDeploymentHandler { function run() public { _beginBroadcast(); // BroadcastManager.s.sol - bool configByMultisig = vm.envBool("DEPLOYMENT_CONFIG_BY_MULTISIG"); + bool configByMultisig; + try vm.envBool("DEPLOYMENT_CONFIG_BY_MULTISIG") returns (bool mult) { + configByMultisig = mult; + } catch { + configByMultisig = false; + } console2.log("configByMultisig:", configByMultisig); if (configByMultisig) { @@ -158,7 +163,18 @@ contract Main is Script, BroadcastManager, JsonDeploymentHandler { contractKey = "AccessController"; _predeploy(contractKey); - accessController = new AccessController(address(governance)); + + address impl = address(new AccessController()); + accessController = AccessController( + TestProxyHelper.deployUUPSProxy( + impl, + abi.encodeCall( + AccessController.initialize, + address(governance) + ) + ) + ); + impl = address(0); // Make sure we don't deploy wrong impl _postdeploy(contractKey, address(accessController)); contractKey = "IPAccountImpl"; @@ -168,7 +184,17 @@ contract Main is Script, BroadcastManager, JsonDeploymentHandler { contractKey = "ModuleRegistry"; _predeploy(contractKey); - moduleRegistry = new ModuleRegistry(address(governance)); + impl = address(new ModuleRegistry()); + moduleRegistry = ModuleRegistry( + TestProxyHelper.deployUUPSProxy( + impl, + abi.encodeCall( + AccessController.initialize, + address(governance) + ) + ) + ); + impl = address(0); // Make sure we don't deploy wrong impl _postdeploy(contractKey, address(moduleRegistry)); contractKey = "IPAccountRegistry"; @@ -207,7 +233,7 @@ contract Main is Script, BroadcastManager, JsonDeploymentHandler { contractKey = "LicenseRegistry"; _predeploy(contractKey); - address impl = address(new LicenseRegistry()); + impl = address(new LicenseRegistry()); licenseRegistry = LicenseRegistry( TestProxyHelper.deployUUPSProxy( impl, @@ -219,17 +245,31 @@ contract Main is Script, BroadcastManager, JsonDeploymentHandler { ) ) ); + impl = address(0); // Make sure we don't deploy wrong impl _postdeploy(contractKey, address(licenseRegistry)); contractKey = "LicensingModule"; _predeploy(contractKey); - licensingModule = new LicensingModule( - address(accessController), - address(ipAccountRegistry), - address(royaltyModule), - address(licenseRegistry), - address(disputeModule) + + impl = address( + new LicensingModule( + address(accessController), + address(ipAccountRegistry), + address(royaltyModule), + address(licenseRegistry), + address(disputeModule) + ) ); + licensingModule = LicensingModule( + TestProxyHelper.deployUUPSProxy( + impl, + abi.encodeCall( + LicensingModule.initialize, + address(governance) + ) + ) + ); + impl = address(0); // Make sure we don't deploy wrong impl _postdeploy(contractKey, address(licensingModule)); contractKey = "IPResolver"; @@ -273,13 +313,26 @@ contract Main is Script, BroadcastManager, JsonDeploymentHandler { _postdeploy(contractKey, address(ancestorsVaultImpl)); _predeploy("PILPolicyFrameworkManager"); - pilPfm = new PILPolicyFrameworkManager( - address(accessController), - address(ipAccountRegistry), - address(licensingModule), - "pil", - "https://github.com/storyprotocol/protocol-core/blob/main/PIL-Beta-2024-02.pdf" + impl = address( + new PILPolicyFrameworkManager( + address(accessController), + address(ipAccountRegistry), + address(licensingModule) + ) + ); + pilPfm = PILPolicyFrameworkManager( + TestProxyHelper.deployUUPSProxy( + impl, + abi.encodeCall( + PILPolicyFrameworkManager.initialize, + ( + "pil", + "https://github.com/storyprotocol/protocol-core/blob/main/PIL-Beta-2024-02.pdf" + ) + ) + ) ); + impl = address(0); // Make sure we don't deploy wrong impl _postdeploy("PILPolicyFrameworkManager", address(pilPfm)); // @@ -319,7 +372,7 @@ contract Main is Script, BroadcastManager, JsonDeploymentHandler { } function _configureAccessController() private { - accessController.initialize(address(ipAccountRegistry), address(moduleRegistry)); + accessController.setAddresses(address(ipAccountRegistry), address(moduleRegistry)); accessController.setGlobalPermission( address(ipAssetRegistry), diff --git a/script/foundry/utils/ERC7201Helper.s.sol b/script/foundry/utils/ERC7201Helper.s.sol deleted file mode 100644 index 40886d21b..000000000 --- a/script/foundry/utils/ERC7201Helper.s.sol +++ /dev/null @@ -1,63 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.23; - -import { Script } from "forge-std/Script.sol"; -import { console2 } from "forge-std/console2.sol"; - -/// @title ERC7201 Helper Script -/// @author Raul Martinez (@Ramarti) -/// @notice This script logs the boilerplate code for ERC7201 storage location and getter function, to -/// help developers implement the ERC7201 interface in their contracts. -/// Thanks Mikhail Vladimirov for bytes32 to hex string conversion functions. -/// https://stackoverflow.com/questions/67893318/solidity-how-to-represent-bytes32-as-string -contract ERC7201HelperScript is Script { - - string constant NAMESPACE = "story-protocol"; - string constant CONTRACT_NAME = "MockLicenseRegistryV2"; - - function run() external { - bytes memory erc7201Key = abi.encodePacked(NAMESPACE,".", CONTRACT_NAME); - bytes32 hash = keccak256(abi.encode(uint256(keccak256(erc7201Key)) - 1)) & ~bytes32(uint256(0xff)); - - // Log natspec and storage struct - console2.log(string(abi.encodePacked("/// @custom:storage-location erc7201:", erc7201Key))); - console2.log(string(abi.encodePacked("struct ", CONTRACT_NAME, "Storage {"))); - console2.log(" // Write storage variables here..."); - console2.log(string(abi.encodePacked("}"))); - console2.log(""); - - // Log ERC7201 comment and storage location - console2.log(string(abi.encodePacked("// keccak256(abi.encode(uint256(keccak256(",'"', erc7201Key,'"',")) - 1)) & ~bytes32(uint256(0xff));"))); - console2.log(string(abi.encodePacked("bytes32 private constant ", CONTRACT_NAME, "StorageLocation = ", toHexString(hash), ";"))); - console2.log(""); - - // Log getter function - console2.log(string(abi.encodePacked("function _get", CONTRACT_NAME, "Storage() private pure returns (", CONTRACT_NAME, "Storage storage $) {"))); - console2.log(string(abi.encodePacked(" assembly {"))); - console2.log(string(abi.encodePacked(" $.slot := ", CONTRACT_NAME, "StorageLocation"))); - console2.log(string(abi.encodePacked(" }"))); - console2.log(string(abi.encodePacked("}"))); - - } - - function toHex16(bytes16 data) internal pure returns (bytes32 result) { - result = bytes32 (data) & 0xFFFFFFFFFFFFFFFF000000000000000000000000000000000000000000000000 | - (bytes32 (data) & 0x0000000000000000FFFFFFFFFFFFFFFF00000000000000000000000000000000) >> 64; - result = result & 0xFFFFFFFF000000000000000000000000FFFFFFFF000000000000000000000000 | - (result & 0x00000000FFFFFFFF000000000000000000000000FFFFFFFF0000000000000000) >> 32; - result = result & 0xFFFF000000000000FFFF000000000000FFFF000000000000FFFF000000000000 | - (result & 0x0000FFFF000000000000FFFF000000000000FFFF000000000000FFFF00000000) >> 16; - result = result & 0xFF000000FF000000FF000000FF000000FF000000FF000000FF000000FF000000 | - (result & 0x00FF000000FF000000FF000000FF000000FF000000FF000000FF000000FF0000) >> 8; - result = (result & 0xF000F000F000F000F000F000F000F000F000F000F000F000F000F000F000F000) >> 4 | - (result & 0x0F000F000F000F000F000F000F000F000F000F000F000F000F000F000F000F00) >> 8; - result = bytes32 (0x3030303030303030303030303030303030303030303030303030303030303030 + - uint256 (result) + - (uint256 (result) + 0x0606060606060606060606060606060606060606060606060606060606060606 >> 4 & - 0x0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F) * 39); - } - - function toHexString(bytes32 data) internal pure returns (string memory) { - return string (abi.encodePacked ("0x", toHex16 (bytes16 (data)), toHex16 (bytes16 (data << 128)))); - } -} \ No newline at end of file diff --git a/script/foundry/utils/upgrades/ERC7201Helper.s.sol b/script/foundry/utils/upgrades/ERC7201Helper.s.sol new file mode 100644 index 000000000..b7904c969 --- /dev/null +++ b/script/foundry/utils/upgrades/ERC7201Helper.s.sol @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +import { Script } from "forge-std/Script.sol"; +import { console2 } from "forge-std/console2.sol"; + +/// @title ERC7201 Helper Script +/// @notice This script logs the boilerplate code for ERC7201 storage location and getter function, to +/// help developers implement the ERC7201 interface in their contracts. +/// Thanks Mikhail Vladimirov for bytes32 to hex string conversion functions. +/// https://stackoverflow.com/questions/67893318/solidity-how-to-represent-bytes32-as-string +contract ERC7201HelperScript is Script { + string constant NAMESPACE = "story-protocol"; + string constant CONTRACT_NAME = "LicensorApprovalChecker"; + + function run() external { + bytes memory erc7201Key = abi.encodePacked(NAMESPACE, ".", CONTRACT_NAME); + bytes32 hash = keccak256(abi.encode(uint256(keccak256(erc7201Key)) - 1)) & ~bytes32(uint256(0xff)); + + // Log natspec and storage struct + console2.log(string(abi.encodePacked("/// @custom:storage-location erc7201:", erc7201Key))); + console2.log(string(abi.encodePacked("struct ", CONTRACT_NAME, "Storage {"))); + console2.log(" // Write storage variables here..."); + console2.log(string(abi.encodePacked("}"))); + console2.log(""); + + // Log ERC7201 comment and storage location + console2.log( + string( + abi.encodePacked( + "// keccak256(abi.encode(uint256(keccak256(", + '"', + erc7201Key, + '"', + ")) - 1)) & ~bytes32(uint256(0xff));" + ) + ) + ); + console2.log( + string( + abi.encodePacked( + "bytes32 private constant ", + CONTRACT_NAME, + "StorageLocation = ", + toHexString(hash), + ";" + ) + ) + ); + console2.log(""); + + // Log getter function + console2.log(string(abi.encodePacked("/// @dev Returns the storage struct of ", CONTRACT_NAME, "."))); + console2.log( + string( + abi.encodePacked( + "function _get", + CONTRACT_NAME, + "Storage() private pure returns (", + CONTRACT_NAME, + "Storage storage $) {" + ) + ) + ); + console2.log(string(abi.encodePacked(" assembly {"))); + console2.log(string(abi.encodePacked(" $.slot := ", CONTRACT_NAME, "StorageLocation"))); + console2.log(string(abi.encodePacked(" }"))); + console2.log(string(abi.encodePacked("}"))); + } + + function toHex16(bytes16 data) internal pure returns (bytes32 result) { + result = + (bytes32(data) & 0xFFFFFFFFFFFFFFFF000000000000000000000000000000000000000000000000) | + ((bytes32(data) & 0x0000000000000000FFFFFFFFFFFFFFFF00000000000000000000000000000000) >> 64); + result = + (result & 0xFFFFFFFF000000000000000000000000FFFFFFFF000000000000000000000000) | + ((result & 0x00000000FFFFFFFF000000000000000000000000FFFFFFFF0000000000000000) >> 32); + result = + (result & 0xFFFF000000000000FFFF000000000000FFFF000000000000FFFF000000000000) | + ((result & 0x0000FFFF000000000000FFFF000000000000FFFF000000000000FFFF00000000) >> 16); + result = + (result & 0xFF000000FF000000FF000000FF000000FF000000FF000000FF000000FF000000) | + ((result & 0x00FF000000FF000000FF000000FF000000FF000000FF000000FF000000FF0000) >> 8); + result = + ((result & 0xF000F000F000F000F000F000F000F000F000F000F000F000F000F000F000F000) >> 4) | + ((result & 0x0F000F000F000F000F000F000F000F000F000F000F000F000F000F000F000F00) >> 8); + result = bytes32( + 0x3030303030303030303030303030303030303030303030303030303030303030 + + uint256(result) + + (((uint256(result) + 0x0606060606060606060606060606060606060606060606060606060606060606) >> 4) & + 0x0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F) * + 39 + ); + } + + function toHexString(bytes32 data) internal pure returns (string memory) { + return string(abi.encodePacked("0x", toHex16(bytes16(data)), toHex16(bytes16(data << 128)))); + } +} diff --git a/test/foundry/IPAccountMetaTx.t.sol b/test/foundry/IPAccountMetaTx.t.sol index 391ba2b12..13410026b 100644 --- a/test/foundry/IPAccountMetaTx.t.sol +++ b/test/foundry/IPAccountMetaTx.t.sol @@ -29,7 +29,6 @@ contract IPAccountMetaTxTest is BaseTest { buildDeployRegistryCondition(DeployRegistryCondition({ moduleRegistry: true, licenseRegistry: false })); deployConditionally(); postDeploymentSetup(); - ownerPrivateKey = 0xA11111; callerPrivateKey = 0xB22222; owner = vm.addr(ownerPrivateKey); diff --git a/test/foundry/integration/flows/disputes/Disputes.t.sol b/test/foundry/integration/flows/disputes/Disputes.t.sol index 8c65f3fed..4e7440db3 100644 --- a/test/foundry/integration/flows/disputes/Disputes.t.sol +++ b/test/foundry/integration/flows/disputes/Disputes.t.sol @@ -23,7 +23,7 @@ contract Flows_Integration_Disputes is BaseIntegration { super.setUp(); // Register PIL Framework - _deployLFM_PIL(); + _setPILPolicyFrameworkManager(); // Register a License _mapPILPolicySimple({ diff --git a/test/foundry/integration/flows/licensing/LicensingScenarios.t.sol b/test/foundry/integration/flows/licensing/LicensingScenarios.t.sol index a383b13a0..90b3dfc85 100644 --- a/test/foundry/integration/flows/licensing/LicensingScenarios.t.sol +++ b/test/foundry/integration/flows/licensing/LicensingScenarios.t.sol @@ -23,7 +23,7 @@ contract Licensing_Scenarios is BaseIntegration { super.setUp(); // Register PIL Framework - _deployLFM_PIL(); + _setPILPolicyFrameworkManager(); // Register an original work with both policies set mockNFT.mintId(u.alice, 1); diff --git a/test/foundry/integration/flows/royalty/Royalty.t.sol b/test/foundry/integration/flows/royalty/Royalty.t.sol index 156707852..c4490f2ef 100644 --- a/test/foundry/integration/flows/royalty/Royalty.t.sol +++ b/test/foundry/integration/flows/royalty/Royalty.t.sol @@ -28,7 +28,7 @@ contract Flows_Integration_Disputes is BaseIntegration { super.setUp(); // Register PIL Framework - _deployLFM_PIL(); + _setPILPolicyFrameworkManager(); royaltyPolicyAddr = address(royaltyPolicyLAP); mintingFeeToken = address(erc20); diff --git a/test/foundry/mocks/access/MockAccessController.sol b/test/foundry/mocks/access/MockAccessController.sol index 01eb479cc..05d708e8e 100644 --- a/test/foundry/mocks/access/MockAccessController.sol +++ b/test/foundry/mocks/access/MockAccessController.sol @@ -9,7 +9,7 @@ import { AccessPermission } from "contracts/lib/AccessPermission.sol"; contract MockAccessController is IAccessController { bool public isAllowed = true; - function initialize(address ipAccountRegistry, address moduleRegistry) external {} + function setAddresses(address ipAccountRegistry, address moduleRegistry) external {} function setAllowed(bool _isAllowed) external { isAllowed = _isAllowed; diff --git a/test/foundry/mocks/licensing/MockPolicyFrameworkManager.sol b/test/foundry/mocks/licensing/MockPolicyFrameworkManager.sol index 9e95bd596..1bd11140a 100644 --- a/test/foundry/mocks/licensing/MockPolicyFrameworkManager.sol +++ b/test/foundry/mocks/licensing/MockPolicyFrameworkManager.sol @@ -24,11 +24,12 @@ contract MockPolicyFrameworkManager is BasePolicyFrameworkManager { event MockPolicyAdded(uint256 indexed policyId, MockPolicy policy); - constructor( - MockPolicyFrameworkConfig memory conf - ) BasePolicyFrameworkManager(conf.licensingModule, conf.name, conf.licenseUrl) { + /// @custom:oz-upgrades-unsafe-allow constructor + constructor(MockPolicyFrameworkConfig memory conf) BasePolicyFrameworkManager(conf.licensingModule) { config = conf; royaltyPolicy = conf.royaltyPolicy; + _getBasePolicyFrameworkManagerStorage().name = conf.name; + _getBasePolicyFrameworkManagerStorage().licenseTextUrl = conf.licenseUrl; } function registerPolicy(MockPolicy calldata mockPolicy) external returns (uint256 policyId) { diff --git a/test/foundry/modules/licensing/LicensingModule.t.sol b/test/foundry/modules/licensing/LicensingModule.t.sol index 32631b40c..c9b85f0eb 100644 --- a/test/foundry/modules/licensing/LicensingModule.t.sol +++ b/test/foundry/modules/licensing/LicensingModule.t.sol @@ -10,7 +10,7 @@ import { AccessPermission } from "contracts/lib/AccessPermission.sol"; import { Errors } from "contracts/lib/Errors.sol"; import { Licensing } from "contracts/lib/Licensing.sol"; import { RegisterPILPolicyParams } from "contracts/interfaces/modules/licensing/IPILPolicyFrameworkManager.sol"; -import { PILPolicyFrameworkManager, PILPolicy } from "contracts/modules/licensing/PILPolicyFrameworkManager.sol"; +import { PILPolicy } from "contracts/modules/licensing/PILPolicyFrameworkManager.sol"; // test // solhint-disable-next-line max-line-length @@ -26,7 +26,6 @@ contract LicensingModuleTest is BaseTest { MockAccessController internal mockAccessController = new MockAccessController(); MockPolicyFrameworkManager internal mockPFM; - PILPolicyFrameworkManager internal pilManager; MockERC721 internal nft = new MockERC721("MockERC721"); MockERC721 internal gatedNftFoo = new MockERC721{ salt: bytes32(uint256(1)) }("GatedNftFoo"); @@ -66,14 +65,6 @@ contract LicensingModuleTest is BaseTest { }) ); - pilManager = new PILPolicyFrameworkManager( - address(mockAccessController), - address(ipAccountRegistry), - address(licensingModule), - "PILPolicyFrameworkManager", - licenseUrl - ); - // Create IPAccounts nft.mintId(ipOwner, 1); nft.mintId(ipOwner, 2); @@ -103,16 +94,8 @@ contract LicensingModuleTest is BaseTest { } function test_LicensingModule_registerPFM() public { - PILPolicyFrameworkManager pfm1 = new PILPolicyFrameworkManager( - address(accessController), - address(ipAccountRegistry), - address(licensingModule), - "PILPolicyFrameworkManager", - licenseUrl - ); - - licensingModule.registerPolicyFrameworkManager(address(pfm1)); - assertTrue(licensingModule.isFrameworkRegistered(address(pfm1))); + licensingModule.registerPolicyFrameworkManager(_deployPILFramework("license Url")); + assertTrue(licensingModule.isFrameworkRegistered(address(_pilFramework()))); } function test_LicensingModule_registerPFM_revert_invalidPolicyFramework() public { @@ -121,16 +104,10 @@ contract LicensingModuleTest is BaseTest { } function test_LicensingModule_registerPFM_revert_emptyLicenseUrl() public { - PILPolicyFrameworkManager pfm1 = new PILPolicyFrameworkManager( - address(accessController), - address(ipAccountRegistry), - address(licensingModule), - "PILPolicyFrameworkManager", - "" - ); + _deployPILFramework(""); vm.expectRevert(Errors.LicensingModule__EmptyLicenseUrl.selector); - licensingModule.registerPolicyFrameworkManager(address(pfm1)); + licensingModule.registerPolicyFrameworkManager(address(_pilFramework())); } function test_LicensingModule_registerPolicy_revert_frameworkNotFound() public { @@ -585,7 +562,7 @@ contract LicensingModuleTest is BaseTest { } function test_LicensingModule_revert_HookVerifyFail() public { - licensingModule.registerPolicyFrameworkManager(address(pilManager)); + _setPILPolicyFrameworkManager(); PILPolicy memory policyData = PILPolicy({ attribution: true, @@ -612,7 +589,7 @@ contract LicensingModuleTest is BaseTest { policyData.territories[0] = "territory1"; policyData.distributionChannels[0] = "distributionChannel1"; - uint256 policyId = pilManager.registerPolicy( + uint256 policyId = _pilFramework().registerPolicy( RegisterPILPolicyParams({ transferable: true, royaltyPolicy: address(mockRoyaltyPolicyLAP), diff --git a/test/foundry/modules/licensing/PILPolicyFramework.derivation.t.sol b/test/foundry/modules/licensing/PILPolicyFramework.derivation.t.sol index 57dc42161..e01139cfa 100644 --- a/test/foundry/modules/licensing/PILPolicyFramework.derivation.t.sol +++ b/test/foundry/modules/licensing/PILPolicyFramework.derivation.t.sol @@ -5,28 +5,14 @@ import { IAccessController } from "contracts/interfaces/IAccessController.sol"; import { ILicensingModule } from "contracts/interfaces/modules/licensing/ILicensingModule.sol"; import { IRoyaltyModule } from "contracts/interfaces/modules/royalty/IRoyaltyModule.sol"; import { Errors } from "contracts/lib/Errors.sol"; -import { PILPolicyFrameworkManager } from "contracts/modules/licensing/PILPolicyFrameworkManager.sol"; import { BaseTest } from "test/foundry/utils/BaseTest.t.sol"; contract PILPolicyFrameworkCompatibilityTest is BaseTest { - PILPolicyFrameworkManager internal pilFramework; - string internal licenseUrl = "https://example.com/license"; address internal ipId1; address internal ipId2; - modifier withPILPolicySimple( - string memory name, - bool commercial, - bool derivatives, - bool reciprocal - ) { - _mapPILPolicySimple(name, commercial, derivatives, reciprocal, 100); - _addPILPolicyFromMapping(name, address(pilFramework)); - _; - } - modifier withAliceOwningDerivativeIp2(string memory policyName) { // Must add the policy first to set the royalty policy (if policy is commercial) // Otherwise, minting license will fail because there's no royalty policy set for license policy, @@ -59,15 +45,7 @@ contract PILPolicyFrameworkCompatibilityTest is BaseTest { licensingModule = ILicensingModule(getLicensingModule()); royaltyModule = IRoyaltyModule(getRoyaltyModule()); - pilFramework = new PILPolicyFrameworkManager( - address(accessController), - address(ipAccountRegistry), - address(licensingModule), - "PILPolicyFrameworkManager", - licenseUrl - ); - - licensingModule.registerPolicyFrameworkManager(address(pilFramework)); + _setPILPolicyFrameworkManager(); mockNFT.mintId(bob, 1); mockNFT.mintId(alice, 2); @@ -182,7 +160,7 @@ contract PILPolicyFrameworkCompatibilityTest is BaseTest { _mapPILPolicySimple("other_policy", true, true, false, 100); _getMappedPilPolicy("other_policy").attribution = false; - _addPILPolicyFromMapping("other_policy", address(pilFramework)); + _addPILPolicyFromMapping("other_policy", address(_pilFramework())); vm.expectRevert(Errors.LicensingModule__DerivativesCannotAddPolicy.selector); vm.prank(alice); diff --git a/test/foundry/modules/licensing/PILPolicyFramework.multi-parent.sol b/test/foundry/modules/licensing/PILPolicyFramework.multi-parent.sol index 4c4127b7f..c8cbf7544 100644 --- a/test/foundry/modules/licensing/PILPolicyFramework.multi-parent.sol +++ b/test/foundry/modules/licensing/PILPolicyFramework.multi-parent.sol @@ -9,12 +9,10 @@ import { Licensing } from "contracts/lib/Licensing.sol"; import { PILFrameworkErrors } from "contracts/lib/PILFrameworkErrors.sol"; // solhint-disable-next-line max-line-length import { RegisterPILPolicyParams } from "contracts/interfaces/modules/licensing/IPILPolicyFrameworkManager.sol"; -import { PILPolicyFrameworkManager } from "contracts/modules/licensing/PILPolicyFrameworkManager.sol"; import { BaseTest } from "test/foundry/utils/BaseTest.t.sol"; contract PILPolicyFrameworkMultiParentTest is BaseTest { - PILPolicyFrameworkManager internal pilFramework; string internal licenseUrl = "https://example.com/license"; address internal ipId1; address internal ipId2; @@ -25,17 +23,6 @@ contract PILPolicyFrameworkMultiParentTest is BaseTest { mapping(address => address) internal ipIdToOwner; - modifier withPILPolicySimple( - string memory name, - bool commercial, - bool derivatives, - bool reciprocal - ) { - _mapPILPolicySimple(name, commercial, derivatives, reciprocal, 100); - _addPILPolicyFromMapping(name, address(pilFramework)); - _; - } - modifier withLicense( string memory policyName, address ipId, @@ -68,15 +55,9 @@ contract PILPolicyFrameworkMultiParentTest is BaseTest { licensingModule = ILicensingModule(getLicensingModule()); royaltyModule = IRoyaltyModule(getRoyaltyModule()); - pilFramework = new PILPolicyFrameworkManager( - address(accessController), - address(ipAccountRegistry), - address(licensingModule), - "PILPolicyFrameworkManager", - licenseUrl - ); + _setPILPolicyFrameworkManager(); - licensingModule.registerPolicyFrameworkManager(address(pilFramework)); + licensingModule.registerPolicyFrameworkManager(address(_pilFramework())); mockNFT.mintId(bob, 1); mockNFT.mintId(bob, 2); @@ -139,7 +120,7 @@ contract PILPolicyFrameworkMultiParentTest is BaseTest { // Save a new policy (change some value to change the policyId) _mapPILPolicySimple("other", true, true, true, 100); _getMappedPilPolicy("other").attribution = !_getMappedPilPolicy("other").attribution; - _addPILPolicyFromMapping("other", address(pilFramework)); + _addPILPolicyFromMapping("other", address(_pilFramework())); vm.prank(ipId3); licenses.push(licensingModule.mintLicense(_getPilPolicyId("other"), ipId3, 1, alice, "")); @@ -422,11 +403,11 @@ contract PILPolicyFrameworkMultiParentTest is BaseTest { RegisterPILPolicyParams memory inputA, RegisterPILPolicyParams memory inputB ) internal returns (uint256 polAId, uint256 polBId) { - polAId = pilFramework.registerPolicy(inputA); + polAId = _pilFramework().registerPolicy(inputA); vm.prank(ipId1); licenses.push(licensingModule.mintLicense(polAId, ipId1, 1, alice, "")); - polBId = pilFramework.registerPolicy(inputB); + polBId = _pilFramework().registerPolicy(inputB); vm.prank(ipId2); licenses.push(licensingModule.mintLicense(polBId, ipId2, 2, alice, "")); } diff --git a/test/foundry/modules/licensing/PILPolicyFramework.t.sol b/test/foundry/modules/licensing/PILPolicyFramework.t.sol index e18758902..7a25aa260 100644 --- a/test/foundry/modules/licensing/PILPolicyFramework.t.sol +++ b/test/foundry/modules/licensing/PILPolicyFramework.t.sol @@ -7,15 +7,12 @@ import { Errors } from "contracts/lib/Errors.sol"; import { PILFrameworkErrors } from "contracts/lib/PILFrameworkErrors.sol"; // solhint-disable-next-line max-line-length import { PILPolicy, RegisterPILPolicyParams } from "contracts/interfaces/modules/licensing/IPILPolicyFrameworkManager.sol"; -import { PILPolicyFrameworkManager } from "contracts/modules/licensing/PILPolicyFrameworkManager.sol"; import { MockERC721 } from "test/foundry/mocks/token/MockERC721.sol"; import { MockTokenGatedHook } from "test/foundry/mocks/MockTokenGatedHook.sol"; import { BaseTest } from "test/foundry/utils/BaseTest.t.sol"; contract PILPolicyFrameworkTest is BaseTest { - PILPolicyFrameworkManager internal pilFramework; - string public licenseUrl = "https://example.com/license"; address public ipId1; address public ipId2; @@ -44,15 +41,9 @@ contract PILPolicyFrameworkTest is BaseTest { accessController = IAccessController(getAccessController()); licensingModule = ILicensingModule(getLicensingModule()); - pilFramework = new PILPolicyFrameworkManager( - address(accessController), - address(ipAccountRegistry), - address(licensingModule), - "PILPolicyFrameworkManager", - licenseUrl - ); + _setPILPolicyFrameworkManager(); - licensingModule.registerPolicyFrameworkManager(address(pilFramework)); + licensingModule.registerPolicyFrameworkManager(address(_pilFramework())); mockNFT.mintId(alice, 1); mockNFT.mintId(alice, 2); @@ -76,8 +67,8 @@ contract PILPolicyFrameworkTest is BaseTest { RegisterPILPolicyParams memory inputA = _getMappedPilParams("pol_a"); inputA.policy.territories = territories; inputA.policy.distributionChannels = distributionChannels; - uint256 policyId = pilFramework.registerPolicy(inputA); - PILPolicy memory policy = pilFramework.getPILPolicy(policyId); + uint256 policyId = _pilFramework().registerPolicy(inputA); + PILPolicy memory policy = _pilFramework().getPILPolicy(policyId); assertEq(keccak256(abi.encode(policy)), keccak256(abi.encode(inputA.policy))); } @@ -99,7 +90,7 @@ contract PILPolicyFrameworkTest is BaseTest { }); vm.prank(address(licensingModule)); - bool verified = pilFramework.verifyLink(0, alice, ipId1, address(0), abi.encode(policyData)); + bool verified = _pilFramework().verifyLink(0, alice, ipId1, address(0), abi.encode(policyData)); assertFalse(verified); } @@ -121,7 +112,7 @@ contract PILPolicyFrameworkTest is BaseTest { }); vm.prank(address(licensingModule)); - bool verified = pilFramework.verifyMint(alice, false, ipId1, alice, 2, abi.encode(policyData)); + bool verified = _pilFramework().verifyMint(alice, false, ipId1, alice, 2, abi.encode(policyData)); assertFalse(verified); } @@ -143,13 +134,13 @@ contract PILPolicyFrameworkTest is BaseTest { }); vm.prank(address(licensingModule)); - bool verified = pilFramework.verifyMint(alice, false, ipId1, alice, 2, abi.encode(policyData)); + bool verified = _pilFramework().verifyMint(alice, false, ipId1, alice, 2, abi.encode(policyData)); assertFalse(verified); } function test_PILPolicyFrameworkManager__getAggregator_revert_emptyAggregator() public { vm.expectRevert(PILFrameworkErrors.PILPolicyFrameworkManager__RightsNotFound.selector); - pilFramework.getAggregator(ipId1); + _pilFramework().getAggregator(ipId1); } ///////////////////////////////////////////////////////////// @@ -170,7 +161,7 @@ contract PILPolicyFrameworkTest is BaseTest { // CHECK: commercialAttribution = true should revert inputA.policy.commercialAttribution = true; vm.expectRevert(PILFrameworkErrors.PILPolicyFrameworkManager__CommercialDisabled_CantAddAttribution.selector); - pilFramework.registerPolicy(inputA); + _pilFramework().registerPolicy(inputA); // reset inputA.policy.commercialAttribution = false; @@ -181,7 +172,7 @@ contract PILPolicyFrameworkTest is BaseTest { vm.expectRevert( PILFrameworkErrors.PILPolicyFrameworkManager__CommercialDisabled_CantAddCommercializers.selector ); - pilFramework.registerPolicy(inputA); + _pilFramework().registerPolicy(inputA); // reset inputA.policy.commercializerChecker = address(0); @@ -190,7 +181,7 @@ contract PILPolicyFrameworkTest is BaseTest { // CHECK: No rev share should be set; revert inputA.policy.commercialRevShare = 1; vm.expectRevert(PILFrameworkErrors.PILPolicyFrameworkManager__CommercialDisabled_CantAddRevShare.selector); - pilFramework.registerPolicy(inputA); + _pilFramework().registerPolicy(inputA); // reset inputA.policy.commercialRevShare = 0; @@ -198,7 +189,7 @@ contract PILPolicyFrameworkTest is BaseTest { // CHECK: royaltyPolicy != address(0) should revert inputA.royaltyPolicy = address(0x123123); vm.expectRevert(PILFrameworkErrors.PILPolicyFrameworkManager__CommercialDisabled_CantAddRoyaltyPolicy.selector); - pilFramework.registerPolicy(inputA); + _pilFramework().registerPolicy(inputA); // reset inputA.royaltyPolicy = address(0); @@ -206,7 +197,7 @@ contract PILPolicyFrameworkTest is BaseTest { // CHECK: mintingFee > 0 should revert inputA.mintingFee = 100; vm.expectRevert(PILFrameworkErrors.PILPolicyFrameworkManager__CommercialDisabled_CantAddMintingFee.selector); - pilFramework.registerPolicy(inputA); + _pilFramework().registerPolicy(inputA); // reset inputA.mintingFee = 0; @@ -226,7 +217,7 @@ contract PILPolicyFrameworkTest is BaseTest { // CHECK: royaltyPolicy == address(0) should revert inputA.royaltyPolicy = address(0); vm.expectRevert(PILFrameworkErrors.PILPolicyFrameworkManager__CommercialEnabled_RoyaltyPolicyRequired.selector); - pilFramework.registerPolicy(inputA); + _pilFramework().registerPolicy(inputA); // reset inputA.royaltyPolicy = address(0x123123); @@ -244,8 +235,8 @@ contract PILPolicyFrameworkTest is BaseTest { inputA.policy.commercialAttribution = true; inputA.policy.commercializerChecker = address(0); inputA.policy.commercializerCheckerData = ""; - uint256 policyId = pilFramework.registerPolicy(inputA); - PILPolicy memory policy = pilFramework.getPILPolicy(policyId); + uint256 policyId = _pilFramework().registerPolicy(inputA); + PILPolicy memory policy = _pilFramework().getPILPolicy(policyId); assertEq(keccak256(abi.encode(policy)), keccak256(abi.encode(inputA.policy))); } @@ -283,12 +274,12 @@ contract PILPolicyFrameworkTest is BaseTest { invalidCommercializerChecker ) ); - pilFramework.registerPolicy(input); + _pilFramework().registerPolicy(input); input.policy.commercializerChecker = address(tokenGatedHook); input.policy.commercializerCheckerData = invalideCommercializerCheckerData; vm.expectRevert("MockTokenGatedHook: Invalid token address"); - pilFramework.registerPolicy(input); + _pilFramework().registerPolicy(input); } function test_PILPolicyFrameworkManager__derivatives_notAllowed_revert_settingIncompatibleTerms() public { @@ -304,17 +295,17 @@ contract PILPolicyFrameworkTest is BaseTest { inputA.policy.derivativesAttribution = true; // derivativesAttribution = true should revert vm.expectRevert(PILFrameworkErrors.PILPolicyFrameworkManager__DerivativesDisabled_CantAddAttribution.selector); - pilFramework.registerPolicy(inputA); + _pilFramework().registerPolicy(inputA); // Requesting approval for derivatives should revert inputA.policy.derivativesAttribution = false; inputA.policy.derivativesApproval = true; vm.expectRevert(PILFrameworkErrors.PILPolicyFrameworkManager__DerivativesDisabled_CantAddApproval.selector); - pilFramework.registerPolicy(inputA); + _pilFramework().registerPolicy(inputA); // Setting reciprocal license should revert inputA.policy.derivativesApproval = false; inputA.policy.derivativesReciprocal = true; vm.expectRevert(PILFrameworkErrors.PILPolicyFrameworkManager__DerivativesDisabled_CantAddReciprocal.selector); - pilFramework.registerPolicy(inputA); + _pilFramework().registerPolicy(inputA); } function test_PILPolicyFrameworkManager__derivatives_valuesSetCorrectly() public { @@ -327,8 +318,8 @@ contract PILPolicyFrameworkTest is BaseTest { }); RegisterPILPolicyParams memory inputA = _getMappedPilParams("pol_a"); inputA.policy.derivativesAttribution = true; - uint256 policyId = pilFramework.registerPolicy(inputA); - PILPolicy memory policy = pilFramework.getPILPolicy(policyId); + uint256 policyId = _pilFramework().registerPolicy(inputA); + PILPolicy memory policy = _pilFramework().getPILPolicy(policyId); assertEq(keccak256(abi.encode(policy)), keccak256(abi.encode(inputA.policy))); } @@ -346,16 +337,16 @@ contract PILPolicyFrameworkTest is BaseTest { }); RegisterPILPolicyParams memory inputA = _getMappedPilParams("pol_a"); inputA.policy.derivativesApproval = true; - uint256 policyId = pilFramework.registerPolicy(inputA); + uint256 policyId = _pilFramework().registerPolicy(inputA); vm.prank(alice); licensingModule.addPolicyToIp(ipId1, policyId); uint256 licenseId = licensingModule.mintLicense(policyId, ipId1, 1, alice, ""); - assertFalse(pilFramework.isDerivativeApproved(licenseId, ipId2)); + assertFalse(_pilFramework().isDerivativeApproved(licenseId, ipId2)); vm.prank(licenseRegistry.licensorIpId(licenseId)); - pilFramework.setApproval(licenseId, ipId2, false); - assertFalse(pilFramework.isDerivativeApproved(licenseId, ipId2)); + _pilFramework().setApproval(licenseId, ipId2, false); + assertFalse(_pilFramework().isDerivativeApproved(licenseId, ipId2)); uint256[] memory licenseIds = new uint256[](1); licenseIds[0] = licenseId; @@ -375,17 +366,17 @@ contract PILPolicyFrameworkTest is BaseTest { }); RegisterPILPolicyParams memory inputA = _getMappedPilParams("pol_a"); inputA.policy.derivativesApproval = true; - uint256 policyId = pilFramework.registerPolicy(inputA); + uint256 policyId = _pilFramework().registerPolicy(inputA); vm.prank(alice); licensingModule.addPolicyToIp(ipId1, policyId); uint256 licenseId = licensingModule.mintLicense(policyId, ipId1, 1, alice, ""); - assertFalse(pilFramework.isDerivativeApproved(licenseId, ipId2)); + assertFalse(_pilFramework().isDerivativeApproved(licenseId, ipId2)); vm.prank(licenseRegistry.licensorIpId(licenseId)); - pilFramework.setApproval(licenseId, ipId2, true); - assertTrue(pilFramework.isDerivativeApproved(licenseId, ipId2)); + _pilFramework().setApproval(licenseId, ipId2, true); + assertTrue(_pilFramework().isDerivativeApproved(licenseId, ipId2)); uint256[] memory licenseIds = new uint256[](1); licenseIds[0] = licenseId; @@ -409,7 +400,7 @@ contract PILPolicyFrameworkTest is BaseTest { }); RegisterPILPolicyParams memory inputA = _getMappedPilParams("pol_a"); inputA.transferable = true; - uint256 policyId = pilFramework.registerPolicy(inputA); + uint256 policyId = _pilFramework().registerPolicy(inputA); vm.prank(alice); licensingModule.addPolicyToIp(ipId1, policyId); uint256 licenseId = licensingModule.mintLicense(policyId, ipId1, 1, licenseHolder, ""); @@ -431,7 +422,7 @@ contract PILPolicyFrameworkTest is BaseTest { }); RegisterPILPolicyParams memory inputA = _getMappedPilParams("pol_a"); inputA.transferable = false; - uint256 policyId = pilFramework.registerPolicy(inputA); + uint256 policyId = _pilFramework().registerPolicy(inputA); vm.prank(alice); licensingModule.addPolicyToIp(ipId1, policyId); uint256 licenseId = licensingModule.mintLicense(policyId, ipId1, 1, licenseHolder, ""); @@ -466,7 +457,7 @@ contract PILPolicyFrameworkTest is BaseTest { contentRestrictions: emptyStringArray }); - string memory actualJson = pilFramework.policyToJson(abi.encode(policyData)); + string memory actualJson = _pilFramework().policyToJson(abi.encode(policyData)); /* solhint-disable */ string memory expectedJson = '{"trait_type": "Attribution", "value": "false"},{"trait_type": "Commercial Use", "value": "false"},{"trait_type": "Commercial Attribution", "value": "true"},{"trait_type": "Commercial Revenue Share", "max_value": 1000, "value": 0},{"trait_type": "Commercializer Check", "value": "0x0000000000000000000000000000000000000000"},{"trait_type": "Derivatives Allowed", "value": "true"},{"trait_type": "Derivatives Attribution", "value": "false"},{"trait_type": "Derivatives Approval", "value": "false"},{"trait_type": "Derivatives Reciprocal", "value": "false"},{"trait_type": "Territories", "value": ["test1","test2"]},{"trait_type": "Distribution Channels", "value": ["test3"]},'; diff --git a/test/foundry/utils/BaseTest.t.sol b/test/foundry/utils/BaseTest.t.sol index b8adfd266..105d9e87c 100644 --- a/test/foundry/utils/BaseTest.t.sol +++ b/test/foundry/utils/BaseTest.t.sol @@ -105,7 +105,7 @@ contract BaseTest is Test, DeployHelper, LicensingHelper { vm.startPrank(u.admin); // NOTE: accessController is IAccessController, which doesn't expose `initialize` function. - AccessController(address(accessController)).initialize(address(ipAccountRegistry), getModuleRegistry()); + AccessController(address(accessController)).setAddresses(address(ipAccountRegistry), getModuleRegistry()); accessController.setGlobalPermission( address(ipAssetRegistry), diff --git a/test/foundry/utils/DeployHelper.t.sol b/test/foundry/utils/DeployHelper.t.sol index e7e7219c2..fcf151cc4 100644 --- a/test/foundry/utils/DeployHelper.t.sol +++ b/test/foundry/utils/DeployHelper.t.sol @@ -44,7 +44,6 @@ import { MockERC20 } from "../mocks/token/MockERC20.sol"; import { MockERC721 } from "../mocks/token/MockERC721.sol"; import { TestProxyHelper } from "./TestProxyHelper.sol"; - contract DeployHelper { // TODO: three options, auto/mock/real in deploy condition, so that we don't need to manually // call getXXX to get mock contract (if there's no real contract deployed). @@ -220,7 +219,11 @@ contract DeployHelper { console2.log("DeployHelper: Using REAL Governance"); } if (d.accessController) { - accessController = new AccessController(getGovernance()); + address impl = address(new AccessController()); + accessController = AccessController( + TestProxyHelper.deployUUPSProxy(impl, abi.encodeCall(AccessController.initialize, (getGovernance()))) + ); + console2.log("DeployHelper: Using REAL AccessController"); postDeployConditions.accessController_init = true; // Access Controller uses IPAccountRegistry in its initialize function. @@ -235,7 +238,10 @@ contract DeployHelper { function _deployRegistryConditionally(DeployRegistryCondition memory d) public { if (d.moduleRegistry) { - moduleRegistry = new ModuleRegistry(getGovernance()); + address impl = address(new ModuleRegistry()); + moduleRegistry = ModuleRegistry( + TestProxyHelper.deployUUPSProxy(impl, abi.encodeCall(AccessController.initialize, (getGovernance()))) + ); console2.log("DeployHelper: Using REAL ModuleRegistry"); postDeployConditions.moduleRegistry_registerModules = true; } @@ -287,12 +293,17 @@ contract DeployHelper { } if (d.licensingModule) { require(address(ipAccountRegistry) != address(0), "DeployHelper Module: IPAccountRegistry required"); - licensingModule = new LicensingModule( - getAccessController(), - address(ipAccountRegistry), - getRoyaltyModule(), - getLicenseRegistry(), - getDisputeModule() + address impl = address( + new LicensingModule( + getAccessController(), + address(ipAccountRegistry), + getRoyaltyModule(), + getLicenseRegistry(), + getDisputeModule() + ) + ); + licensingModule = LicensingModule( + TestProxyHelper.deployUUPSProxy(impl, abi.encodeCall(LicensingModule.initialize, (getGovernance()))) ); console2.log("DeployHelper: Using REAL LicensingModule"); } diff --git a/test/foundry/utils/LicensingHelper.t.sol b/test/foundry/utils/LicensingHelper.t.sol index f271b86b9..6703bf987 100644 --- a/test/foundry/utils/LicensingHelper.t.sol +++ b/test/foundry/utils/LicensingHelper.t.sol @@ -9,13 +9,10 @@ import { IIPAccountRegistry } from "../../../contracts/interfaces/registries/IIP import { ILicensingModule } from "../../../contracts/interfaces/modules/licensing/ILicensingModule.sol"; import { IRoyaltyModule } from "../../../contracts/interfaces/modules/royalty/IRoyaltyModule.sol"; import { IRoyaltyPolicyLAP } from "../../../contracts/interfaces/modules/royalty/policies/IRoyaltyPolicyLAP.sol"; -import { BasePolicyFrameworkManager } from "../../../contracts/modules/licensing/BasePolicyFrameworkManager.sol"; // solhint-disable-next-line max-line-length import { PILPolicyFrameworkManager, PILPolicy, RegisterPILPolicyParams } from "../../../contracts/modules/licensing/PILPolicyFrameworkManager.sol"; - +import { TestProxyHelper } from "./TestProxyHelper.sol"; // test -// solhint-disable-next-line max-line-length -import { MockPolicyFrameworkManager, MockPolicyFrameworkConfig } from "test/foundry/mocks/licensing/MockPolicyFrameworkManager.sol"; contract LicensingHelper { IAccessController private ACCESS_CONTROLLER; // keep private to avoid collision with `BaseIntegration` @@ -61,7 +58,18 @@ contract LicensingHelper { //////////////////////////////////////////////////////////////////////////*/ modifier withLFM_PIL() { - _deployLFM_PIL(); + _setPILPolicyFrameworkManager(); + _; + } + + modifier withPILPolicySimple( + string memory name, + bool commercial, + bool derivatives, + bool reciprocal + ) { + _mapPILPolicySimple(name, commercial, derivatives, reciprocal, 100); + _addPILPolicyFromMapping(name, address(_pilFramework())); _; } @@ -70,15 +78,25 @@ contract LicensingHelper { //////////////////////////////////////////////////////////////////////////*/ function _setPILPolicyFrameworkManager() internal { - PILPolicyFrameworkManager pilPfm = new PILPolicyFrameworkManager( + _deployPILFramework("license Url"); + LICENSING_MODULE.registerPolicyFrameworkManager(pfm["pil"]); + } + + function _deployPILFramework(string memory licenseUrl) internal returns (address) { + PILPolicyFrameworkManager impl = new PILPolicyFrameworkManager( address(ACCESS_CONTROLLER), address(IP_ACCOUNT_REGISTRY), - address(LICENSING_MODULE), - "PIL_MINT_PAYMENT", - "license Url" + address(LICENSING_MODULE) + ); + pfm["pil"] = TestProxyHelper.deployUUPSProxy( + address(impl), + abi.encodeCall(PILPolicyFrameworkManager.initialize, ("PIL_MINT_PAYMENT", licenseUrl)) ); - pfm["pil"] = address(pilPfm); - LICENSING_MODULE.registerPolicyFrameworkManager(address(pilPfm)); + return pfm["pil"]; + } + + function _pilFramework() internal view returns (PILPolicyFrameworkManager) { + return PILPolicyFrameworkManager(pfm["pil"]); } function _addPILPolicy( @@ -207,39 +225,4 @@ contract LicensingHelper { string memory pName = string(abi.encodePacked("pil_", name)); return policyIds[pName]; } - - function _createMockPolicyFrameworkManager( - bool supportVerifyLink, - bool supportVerifyMint - ) private returns (BasePolicyFrameworkManager) { - return - BasePolicyFrameworkManager( - new MockPolicyFrameworkManager( - MockPolicyFrameworkConfig({ - licensingModule: address(LICENSING_MODULE), - name: "mock", - licenseUrl: "license url", - royaltyPolicy: address(0xdeadbeef) - }) - ) - ); - } - - function _deployLFM_PIL() internal { - BasePolicyFrameworkManager _pfm = BasePolicyFrameworkManager( - new PILPolicyFrameworkManager( - address(ACCESS_CONTROLLER), - address(IP_ACCOUNT_REGISTRY), - address(LICENSING_MODULE), - "pil", - "license Url" - ) - ); - LICENSING_MODULE.registerPolicyFrameworkManager(address(_pfm)); - pfm["pil"] = address(_pfm); - } - - function _pilFramework() internal view returns (PILPolicyFrameworkManager) { - return PILPolicyFrameworkManager(pfm["pil"]); - } }