From 5fdbc74d8a67ab4d4304072ad5684f82e0c2c51c Mon Sep 17 00:00:00 2001 From: Sebastian Liu Date: Wed, 14 Aug 2024 11:11:00 -0700 Subject: [PATCH 1/2] fix(register-derivatives): collect and approve mint fees for commercial licenses --- contracts/SPGNFT.sol | 1 + contracts/StoryProtocolGateway.sol | 141 +++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+) diff --git a/contracts/SPGNFT.sol b/contracts/SPGNFT.sol index 24f40a8..1d2bf8c 100644 --- a/contracts/SPGNFT.sol +++ b/contracts/SPGNFT.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.23; import { AccessControlUpgradeable } from "@openzeppelin/contracts-upgradeable/access/AccessControlUpgradeable.sol"; +// solhint-disable-next-line max-line-length import { ERC721URIStorageUpgradeable } from "@openzeppelin/contracts-upgradeable/token/ERC721/extensions/ERC721URIStorageUpgradeable.sol"; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import { IERC165 } from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; diff --git a/contracts/StoryProtocolGateway.sol b/contracts/StoryProtocolGateway.sol index 2a0e12d..c226609 100644 --- a/contracts/StoryProtocolGateway.sol +++ b/contracts/StoryProtocolGateway.sol @@ -13,9 +13,17 @@ import { IAccessController } from "@storyprotocol/core/interfaces/access/IAccess // solhint-disable-next-line max-line-length import { IPILicenseTemplate, PILTerms } from "@storyprotocol/core/interfaces/modules/licensing/IPILicenseTemplate.sol"; import { ILicensingModule } from "@storyprotocol/core/interfaces/modules/licensing/ILicensingModule.sol"; +import { ILicenseTemplate } from "@storyprotocol/core/interfaces/modules/licensing/ILicenseTemplate.sol"; +import { ILicenseRegistry } from "@storyprotocol/core/interfaces/registries/ILicenseRegistry.sol"; +import { ILicensingHook } from "@storyprotocol/core/interfaces/modules/licensing/ILicensingHook.sol"; +import { Licensing } from "@storyprotocol/core/lib/Licensing.sol"; +import { IRoyaltyModule } from "@storyprotocol/core/interfaces/modules/royalty/IRoyaltyModule.sol"; + import { ICoreMetadataModule } from "@storyprotocol/core/interfaces/modules/metadata/ICoreMetadataModule.sol"; import { IIPAssetRegistry } from "@storyprotocol/core/interfaces/registries/IIPAssetRegistry.sol"; import { AccessPermission } from "@storyprotocol/core/lib/AccessPermission.sol"; +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { SafeERC20 } from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; import { IStoryProtocolGateway } from "./interfaces/IStoryProtocolGateway.sol"; import { ISPGNFT } from "./interfaces/ISPGNFT.sol"; @@ -24,6 +32,7 @@ import { SPGNFTLib } from "./lib/SPGNFTLib.sol"; contract StoryProtocolGateway is IStoryProtocolGateway, AccessManagedUpgradeable, UUPSUpgradeable { using ERC165Checker for address; + using SafeERC20 for IERC20; /// @dev Storage structure for the SPG /// @param nftContractBeacon The address of the NFT contract beacon. @@ -44,6 +53,12 @@ contract StoryProtocolGateway is IStoryProtocolGateway, AccessManagedUpgradeable /// @notice The address of the Licensing Module. ILicensingModule public immutable LICENSING_MODULE; + /// @notice The address of the License Registry. + ILicenseRegistry public immutable LICENSE_REGISTRY; + + /// @notice The address of the Royalty Module. + IRoyaltyModule public immutable ROYALTY_MODULE; + /// @notice The address of the Core Metadata Module. ICoreMetadataModule public immutable CORE_METADATA_MODULE; @@ -65,6 +80,8 @@ contract StoryProtocolGateway is IStoryProtocolGateway, AccessManagedUpgradeable address accessController, address ipAssetRegistry, address licensingModule, + address licenseRegistry, + address royaltyModule, address coreMetadataModule, address pilTemplate, address licenseToken @@ -73,6 +90,8 @@ contract StoryProtocolGateway is IStoryProtocolGateway, AccessManagedUpgradeable accessController == address(0) || ipAssetRegistry == address(0) || licensingModule == address(0) || + licenseRegistry == address(0) || + royaltyModule == address(0) || coreMetadataModule == address(0) || licenseToken == address(0) ) revert Errors.SPG__ZeroAddressParam(); @@ -80,6 +99,8 @@ contract StoryProtocolGateway is IStoryProtocolGateway, AccessManagedUpgradeable ACCESS_CONTROLLER = IAccessController(accessController); IP_ASSET_REGISTRY = IIPAssetRegistry(ipAssetRegistry); LICENSING_MODULE = ILicensingModule(licensingModule); + LICENSE_REGISTRY = ILicenseRegistry(licenseRegistry); + ROYALTY_MODULE = IRoyaltyModule(royaltyModule); CORE_METADATA_MODULE = ICoreMetadataModule(coreMetadataModule); PIL_TEMPLATE = IPILicenseTemplate(pilTemplate); LICENSE_TOKEN = ILicenseToken(licenseToken); @@ -254,6 +275,14 @@ contract StoryProtocolGateway is IStoryProtocolGateway, AccessManagedUpgradeable ipId = IP_ASSET_REGISTRY.register(block.chainid, nftContract, tokenId); _setMetadata(ipMetadata, ipId); + _collectMintFeesAndSetApproval( + msg.sender, + ipId, + derivData.parentIpIds, + derivData.licenseTemplate, + derivData.licenseTermsIds + ); + LICENSING_MODULE.registerDerivative({ childIpId: ipId, parentIpIds: derivData.parentIpIds, @@ -289,6 +318,15 @@ contract StoryProtocolGateway is IStoryProtocolGateway, AccessManagedUpgradeable address(LICENSING_MODULE), ILicensingModule.registerDerivative.selector ); + + _collectMintFeesAndSetApproval( + msg.sender, + ipId, + derivData.parentIpIds, + derivData.licenseTemplate, + derivData.licenseTermsIds + ); + LICENSING_MODULE.registerDerivative({ childIpId: ipId, parentIpIds: derivData.parentIpIds, @@ -441,6 +479,109 @@ contract StoryProtocolGateway is IStoryProtocolGateway, AccessManagedUpgradeable _setMetadata(ipMetadata, ipId); } + /// @dev Collect mint fees for all parent IPs from the payer and set approval for Royalty Module to spend mint fees. + /// @param payerAddress The address of the payer for the license mint fees. + /// @param childIpId The ID of the derivative IP. + /// @param parentIpIds The IDs of all the parent IPs. + /// @param licenseTemplate The address of the license template. + /// @param licenseTermsIds The IDs of the license terms for each corresponding parent IP. + function _collectMintFeesAndSetApproval( + address payerAddress, + address childIpId, + address[] calldata parentIpIds, + address licenseTemplate, + uint256[] calldata licenseTermsIds + ) private { + // Get currency token and royalty policy, assumes all parent IPs have the same currency token. + ILicenseTemplate lct = ILicenseTemplate(licenseTemplate); + (address royaltyPolicy, , , address mintFeeCurrencyToken) = lct.getRoyaltyPolicy(licenseTermsIds[0]); + + if (royaltyPolicy != address(0)) { + // Get total mint fee for all parent IPs + uint256 totalMintFee = _aggregateMintFees(parentIpIds, childIpId, licenseTemplate, licenseTermsIds); + + if (totalMintFee != 0) { + // Transfer mint fee from payer to this contract + IERC20(mintFeeCurrencyToken).safeTransferFrom(payerAddress, address(this), totalMintFee); + + // Approve Royalty Policy to spend mint fee + IERC20(mintFeeCurrencyToken).forceApprove(royaltyPolicy, totalMintFee); + } + } + } + + /// @dev Aggregate license mint fees for all parent IPs. + /// @param parentIpIds The IDs of all the parent IPs. + /// @param childIpId The ID of the derivative IP. + /// @param licenseTemplate The address of the license template. + /// @param licenseTermsIds The IDs of the license terms for each corresponding parent IP. + /// @return totalMintFee The sum of license mint fees across all parent IPs. + function _aggregateMintFees( + address[] calldata parentIpIds, + address childIpId, + address licenseTemplate, + uint256[] calldata licenseTermsIds + ) private returns (uint256 totalMintFee) { + totalMintFee = 0; + + for (uint256 i = 0; i < parentIpIds.length; i++) { + totalMintFee += _getMintFeeForSingleParent( + childIpId, + parentIpIds[i], + licenseTemplate, + licenseTermsIds[i], + 1 + ); + } + } + + /// @dev Fetch the license token mint fee from the licensing hook or license terms for the given parent IP. + /// @param childIpId The ID of the derivative IP. + /// @param parentIpId The ID of the parent IP. + /// @param licenseTemplate The address of the license template. + /// @param licenseTermsId The ID of the license terms for the parent IP. + /// @param amount The amount of licenses to mint. + function _getMintFeeForSingleParent( + address childIpId, + address parentIpId, + address licenseTemplate, + uint256 licenseTermsId, + uint256 amount + ) private returns (uint256) { + ILicenseTemplate lct = ILicenseTemplate(licenseTemplate); + + // Get mint fee set by license terms + (address royaltyPolicy, , uint256 mintFeeSetByLicenseTerms, ) = lct.getRoyaltyPolicy(licenseTermsId); + + // If no royalty policy, return 0 + if (royaltyPolicy == address(0)) return 0; + + uint256 mintFeeSetByHook = 0; + + Licensing.LicensingConfig memory licensingConfig = LICENSE_REGISTRY.getLicensingConfig( + parentIpId, + licenseTemplate, + licenseTermsId + ); + + // Get mint fee from licensing hook + if (licensingConfig.licensingHook != address(0)) { + mintFeeSetByHook = ILicensingHook(licensingConfig.licensingHook).beforeRegisterDerivative( + address(this), + childIpId, + parentIpId, + licenseTemplate, + licenseTermsId, + licensingConfig.hookData + ); + } + + if (!licensingConfig.isSet) return mintFeeSetByLicenseTerms * amount; + if (licensingConfig.licensingHook == address(0)) return licensingConfig.mintingFee * amount; + + return mintFeeSetByHook; + } + // // Upgrade // From 19f0041a0f4b9fc54dcfafbade35d06eadac2823 Mon Sep 17 00:00:00 2001 From: Sebastian Liu Date: Wed, 14 Aug 2024 19:22:24 -0700 Subject: [PATCH 2/2] test(register-derivatives): add tests for commercial license parent derivatives --- script/Main.s.sol | 2 + .../utils/StoryProtocolCoreAddressManager.sol | 4 + test/StoryProtocolGateway.t.sol | 254 ++++++++++-------- test/utils/BaseTest.t.sol | 50 +++- 4 files changed, 204 insertions(+), 106 deletions(-) diff --git a/script/Main.s.sol b/script/Main.s.sol index 4755598..e46322a 100644 --- a/script/Main.s.sol +++ b/script/Main.s.sol @@ -57,6 +57,8 @@ contract Main is Script, StoryProtocolCoreAddressManager, BroadcastManager, Json accessControllerAddr, ipAssetRegistryAddr, licensingModuleAddr, + licenseRegistryAddr, + royaltyModuleAddr, coreMetadataModuleAddr, pilTemplateAddr, licenseTokenAddr diff --git a/script/utils/StoryProtocolCoreAddressManager.sol b/script/utils/StoryProtocolCoreAddressManager.sol index 5d0affb..b0c4ba5 100644 --- a/script/utils/StoryProtocolCoreAddressManager.sol +++ b/script/utils/StoryProtocolCoreAddressManager.sol @@ -10,6 +10,8 @@ contract StoryProtocolCoreAddressManager is Script { address internal protocolAccessManagerAddr; address internal ipAssetRegistryAddr; address internal licensingModuleAddr; + address internal licenseRegistryAddr; + address internal royaltyModuleAddr; address internal coreMetadataModuleAddr; address internal accessControllerAddr; address internal pilTemplateAddr; @@ -31,6 +33,8 @@ contract StoryProtocolCoreAddressManager is Script { protocolAccessManagerAddr = json.readAddress(".main.ProtocolAccessManager"); ipAssetRegistryAddr = json.readAddress(".main.IPAssetRegistry"); licensingModuleAddr = json.readAddress(".main.LicensingModule"); + licenseRegistryAddr = json.readAddress(".main.LicenseRegistry"); + royaltyModuleAddr = json.readAddress(".main.RoyaltyModule"); coreMetadataModuleAddr = json.readAddress(".main.CoreMetadataModule"); accessControllerAddr = json.readAddress(".main.AccessController"); pilTemplateAddr = json.readAddress(".main.PILicenseTemplate"); diff --git a/test/StoryProtocolGateway.t.sol b/test/StoryProtocolGateway.t.sol index 566b008..854adb3 100644 --- a/test/StoryProtocolGateway.t.sol +++ b/test/StoryProtocolGateway.t.sol @@ -205,8 +205,9 @@ contract StoryProtocolGatewayTest is BaseTest { modifier withEnoughTokens() { require(caller != address(0), "withEnoughTokens: caller not set"); - mockToken.mint(address(caller), 1000 * 10 ** mockToken.decimals()); + mockToken.mint(address(caller), 2000 * 10 ** mockToken.decimals()); mockToken.approve(address(nftContract), 1000 * 10 ** mockToken.decimals()); + mockToken.approve(address(spg), 1000 * 10 ** mockToken.decimals()); _; } @@ -280,7 +281,7 @@ contract StoryProtocolGatewayTest is BaseTest { }); } - modifier withParentIp() { + modifier withNonCommercialParentIp() { (ipIdParent, , ) = spg.mintAndRegisterIpAndAttachPILTerms({ nftContract: address(nftContract), recipient: caller, @@ -291,126 +292,59 @@ contract StoryProtocolGatewayTest is BaseTest { _; } - function test_SPG_mintAndRegisterIpAndMakeDerivative() + function test_SPG_mintAndRegisterIpAndMakeDerivativeWithNonCommercialLicense() public withCollection whenCallerHasMinterRole withEnoughTokens - withParentIp + withNonCommercialParentIp { - (address licenseTemplateParent, uint256 licenseTermsIdParent) = licenseRegistry.getAttachedLicenseTerms( - ipIdParent, - 0 - ); - - address[] memory parentIpIds = new address[](1); - parentIpIds[0] = ipIdParent; + _mintAndRegisterIpAndMakeDerivativeBaseTest(); + } - uint256[] memory licenseTermsIds = new uint256[](1); - licenseTermsIds[0] = licenseTermsIdParent; + function test_SPG_registerIpAndMakeDerivativeWithNonCommercialLicense() + public + withCollection + whenCallerHasMinterRole + withEnoughTokens + withNonCommercialParentIp + { + _registerIpAndMakeDerivativeBaseTest(); + } - (address ipIdChild, uint256 tokenIdChild) = spg.mintAndRegisterIpAndMakeDerivative({ + modifier withCommercialParentIp() { + (ipIdParent, , ) = spg.mintAndRegisterIpAndAttachPILTerms({ nftContract: address(nftContract), - derivData: ISPG.MakeDerivative({ - parentIpIds: parentIpIds, - licenseTemplate: address(pilTemplate), - licenseTermsIds: licenseTermsIds, - royaltyContext: "" - }), + recipient: caller, nftMetadata: nftMetadataDefault, ipMetadata: ipMetadataDefault, - recipient: caller - }); - assertTrue(ipAssetRegistry.isRegistered(ipIdChild)); - assertEq(tokenIdChild, 2); - assertMetadata(ipIdChild, ipMetadataDefault); - assertMetadata(ipIdChild, ipMetadataDefault); - (address licenseTemplateChild, uint256 licenseTermsIdChild) = licenseRegistry.getAttachedLicenseTerms( - ipIdChild, - 0 - ); - assertEq(licenseTemplateChild, licenseTemplateParent); - assertEq(licenseTermsIdChild, licenseTermsIdParent); - assertEq(IIPAccount(payable(ipIdChild)).owner(), caller); - - assertParentChild({ - ipIdParent: ipIdParent, - ipIdChild: ipIdChild, - expectedParentCount: 1, - expectedParentIndex: 0 + terms: PILFlavors.commercialUse({ + mintingFee: 100 * 10 ** mockToken.decimals(), + currencyToken: address(mockToken), + royaltyPolicy: address(royaltyPolicyLAP) + }) }); + _; } - function test_SPG_registerIpAndMakeDerivative() + function test_SPG_mintAndRegisterIpAndMakeDerivativeWithCommercialLicense() public withCollection whenCallerHasMinterRole withEnoughTokens - withParentIp + withCommercialParentIp { - (address licenseTemplateParent, uint256 licenseTermsIdParent) = licenseRegistry.getAttachedLicenseTerms( - ipIdParent, - 0 - ); - - uint256 tokenIdChild = nftContract.mint(address(caller), nftMetadataEmpty); - address ipIdChild = ipAssetRegistry.ipId(block.chainid, address(nftContract), tokenIdChild); - - uint256 deadline = block.timestamp + 1000; - - (bytes memory sigMetadata, ) = _getSetPermissionSignatureForSPG({ - ipId: ipIdChild, - module: address(coreMetadataModule), - selector: ICoreMetadataModule.setAll.selector, - deadline: deadline, - nonce: 1, - signerPk: alicePk - }); - (bytes memory sigRegister, ) = _getSetPermissionSignatureForSPG({ - ipId: ipIdChild, - module: address(licensingModule), - selector: ILicensingModule.registerDerivative.selector, - deadline: deadline, - nonce: 2, - signerPk: alicePk - }); - - address[] memory parentIpIds = new address[](1); - parentIpIds[0] = ipIdParent; - - uint256[] memory licenseTermsIds = new uint256[](1); - licenseTermsIds[0] = licenseTermsIdParent; - - address ipIdChildActual = spg.registerIpAndMakeDerivative({ - nftContract: address(nftContract), - tokenId: tokenIdChild, - derivData: ISPG.MakeDerivative({ - parentIpIds: parentIpIds, - licenseTemplate: address(pilTemplate), - licenseTermsIds: licenseTermsIds, - royaltyContext: "" - }), - ipMetadata: ipMetadataDefault, - sigMetadata: ISPG.SignatureData({ signer: alice, deadline: deadline, signature: sigMetadata }), - sigRegister: ISPG.SignatureData({ signer: alice, deadline: deadline, signature: sigRegister }) - }); - assertEq(ipIdChildActual, ipIdChild); - assertTrue(ipAssetRegistry.isRegistered(ipIdChild)); - assertMetadata(ipIdChild, ipMetadataDefault); - (address licenseTemplateChild, uint256 licenseTermsIdChild) = licenseRegistry.getAttachedLicenseTerms( - ipIdChild, - 0 - ); - assertEq(licenseTemplateChild, licenseTemplateParent); - assertEq(licenseTermsIdChild, licenseTermsIdParent); - assertEq(IIPAccount(payable(ipIdChild)).owner(), caller); + _mintAndRegisterIpAndMakeDerivativeBaseTest(); + } - assertParentChild({ - ipIdParent: ipIdParent, - ipIdChild: ipIdChild, - expectedParentCount: 1, - expectedParentIndex: 0 - }); + function test_SPG_registerIpAndMakeDerivativeWithCommercialLicense() + public + withCollection + whenCallerHasMinterRole + withEnoughTokens + withCommercialParentIp + { + _registerIpAndMakeDerivativeBaseTest(); } function test_SPG_mintAndRegisterIpAndMakeDerivativeWithLicenseTokens() @@ -418,7 +352,7 @@ contract StoryProtocolGatewayTest is BaseTest { withCollection whenCallerHasMinterRole withEnoughTokens - withParentIp + withNonCommercialParentIp { (address licenseTemplateParent, uint256 licenseTermsIdParent) = licenseRegistry.getAttachedLicenseTerms( ipIdParent, @@ -473,7 +407,7 @@ contract StoryProtocolGatewayTest is BaseTest { withCollection whenCallerHasMinterRole withEnoughTokens - withParentIp + withNonCommercialParentIp { caller = alice; vm.startPrank(caller); @@ -611,4 +545,114 @@ contract StoryProtocolGatewayTest is BaseTest { (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerPk, digest); signature = abi.encodePacked(r, s, v); } + + function _mintAndRegisterIpAndMakeDerivativeBaseTest() internal { + (address licenseTemplateParent, uint256 licenseTermsIdParent) = licenseRegistry.getAttachedLicenseTerms( + ipIdParent, + 0 + ); + + address[] memory parentIpIds = new address[](1); + parentIpIds[0] = ipIdParent; + + uint256[] memory licenseTermsIds = new uint256[](1); + licenseTermsIds[0] = licenseTermsIdParent; + + (address ipIdChild, uint256 tokenIdChild) = spg.mintAndRegisterIpAndMakeDerivative({ + nftContract: address(nftContract), + derivData: ISPG.MakeDerivative({ + parentIpIds: parentIpIds, + licenseTemplate: address(pilTemplate), + licenseTermsIds: licenseTermsIds, + royaltyContext: "" + }), + nftMetadata: nftMetadataDefault, + ipMetadata: ipMetadataDefault, + recipient: caller + }); + assertTrue(ipAssetRegistry.isRegistered(ipIdChild)); + assertEq(tokenIdChild, 2); + assertMetadata(ipIdChild, ipMetadataDefault); + assertMetadata(ipIdChild, ipMetadataDefault); + (address licenseTemplateChild, uint256 licenseTermsIdChild) = licenseRegistry.getAttachedLicenseTerms( + ipIdChild, + 0 + ); + assertEq(licenseTemplateChild, licenseTemplateParent); + assertEq(licenseTermsIdChild, licenseTermsIdParent); + assertEq(IIPAccount(payable(ipIdChild)).owner(), caller); + + assertParentChild({ + ipIdParent: ipIdParent, + ipIdChild: ipIdChild, + expectedParentCount: 1, + expectedParentIndex: 0 + }); + } + + function _registerIpAndMakeDerivativeBaseTest() internal { + (address licenseTemplateParent, uint256 licenseTermsIdParent) = licenseRegistry.getAttachedLicenseTerms( + ipIdParent, + 0 + ); + + uint256 tokenIdChild = nftContract.mint(address(caller), nftMetadataEmpty); + address ipIdChild = ipAssetRegistry.ipId(block.chainid, address(nftContract), tokenIdChild); + + uint256 deadline = block.timestamp + 1000; + + (bytes memory sigMetadata, ) = _getSetPermissionSignatureForSPG({ + ipId: ipIdChild, + module: address(coreMetadataModule), + selector: ICoreMetadataModule.setAll.selector, + deadline: deadline, + nonce: 1, + signerPk: alicePk + }); + (bytes memory sigRegister, ) = _getSetPermissionSignatureForSPG({ + ipId: ipIdChild, + module: address(licensingModule), + selector: ILicensingModule.registerDerivative.selector, + deadline: deadline, + nonce: 2, + signerPk: alicePk + }); + + address[] memory parentIpIds = new address[](1); + parentIpIds[0] = ipIdParent; + + uint256[] memory licenseTermsIds = new uint256[](1); + licenseTermsIds[0] = licenseTermsIdParent; + + address ipIdChildActual = spg.registerIpAndMakeDerivative({ + nftContract: address(nftContract), + tokenId: tokenIdChild, + derivData: ISPG.MakeDerivative({ + parentIpIds: parentIpIds, + licenseTemplate: address(pilTemplate), + licenseTermsIds: licenseTermsIds, + royaltyContext: "" + }), + ipMetadata: ipMetadataDefault, + sigMetadata: ISPG.SignatureData({ signer: alice, deadline: deadline, signature: sigMetadata }), + sigRegister: ISPG.SignatureData({ signer: alice, deadline: deadline, signature: sigRegister }) + }); + assertEq(ipIdChildActual, ipIdChild); + assertTrue(ipAssetRegistry.isRegistered(ipIdChild)); + assertMetadata(ipIdChild, ipMetadataDefault); + (address licenseTemplateChild, uint256 licenseTermsIdChild) = licenseRegistry.getAttachedLicenseTerms( + ipIdChild, + 0 + ); + assertEq(licenseTemplateChild, licenseTemplateParent); + assertEq(licenseTermsIdChild, licenseTermsIdParent); + assertEq(IIPAccount(payable(ipIdChild)).owner(), caller); + + assertParentChild({ + ipIdParent: ipIdParent, + ipIdChild: ipIdChild, + expectedParentCount: 1, + expectedParentIndex: 0 + }); + } } diff --git a/test/utils/BaseTest.t.sol b/test/utils/BaseTest.t.sol index 0d1c40c..a57a83d 100644 --- a/test/utils/BaseTest.t.sol +++ b/test/utils/BaseTest.t.sol @@ -16,6 +16,8 @@ import { PILicenseTemplate } from "@storyprotocol/core/modules/licensing/PILicen import { LicensingModule } from "@storyprotocol/core/modules/licensing/LicensingModule.sol"; import { DisputeModule } from "@storyprotocol/core/modules/dispute/DisputeModule.sol"; import { RoyaltyModule } from "@storyprotocol/core/modules/royalty/RoyaltyModule.sol"; +import { RoyaltyPolicyLAP } from "@storyprotocol/core/modules/royalty/policies/RoyaltyPolicyLAP.sol"; +import { IpRoyaltyVault } from "@storyprotocol/core/modules/royalty/policies/IpRoyaltyVault.sol"; import { CoreMetadataModule } from "@storyprotocol/core/modules/metadata/CoreMetadataModule.sol"; import { CoreMetadataViewModule } from "@storyprotocol/core/modules/metadata/CoreMetadataViewModule.sol"; @@ -39,6 +41,10 @@ contract BaseTest is Test { IPAssetRegistry internal ipAssetRegistry; LicenseRegistry internal licenseRegistry; LicensingModule internal licensingModule; + RoyaltyModule internal royaltyModule; + RoyaltyPolicyLAP internal royaltyPolicyLAP; + UpgradeableBeacon internal ipRoyaltyVaultBeacon; + IpRoyaltyVault internal ipRoyaltyVaultImpl; CoreMetadataModule internal coreMetadataModule; CoreMetadataViewModule internal coreMetadataViewModule; PILicenseTemplate internal pilTemplate; @@ -195,7 +201,7 @@ contract BaseTest is Test { address(licenseRegistry) ) ); - RoyaltyModule royaltyModule = RoyaltyModule( + royaltyModule = RoyaltyModule( TestProxyHelper.deployUUPSProxy( create3Deployer, _getSalt(type(RoyaltyModule).name), @@ -234,6 +240,41 @@ contract BaseTest is Test { ); require(_loadProxyImpl(address(licensingModule)) == impl, "LicensingModule Proxy Implementation Mismatch"); + impl = address(new RoyaltyPolicyLAP(address(royaltyModule), address(licensingModule))); + royaltyPolicyLAP = RoyaltyPolicyLAP( + TestProxyHelper.deployUUPSProxy( + create3Deployer, + _getSalt(type(RoyaltyPolicyLAP).name), + impl, + abi.encodeCall(RoyaltyPolicyLAP.initialize, address(protocolAccessManager)) + ) + ); + require( + _getDeployedAddress(type(RoyaltyPolicyLAP).name) == address(royaltyPolicyLAP), + "Deploy: Royalty Policy LAP Address Mismatch" + ); + require(_loadProxyImpl(address(royaltyPolicyLAP)) == impl, "RoyaltyPolicyLAP Proxy Implementation Mismatch"); + + ipRoyaltyVaultImpl = IpRoyaltyVault( + create3Deployer.deploy( + _getSalt(type(IpRoyaltyVault).name), + abi.encodePacked( + type(IpRoyaltyVault).creationCode, + abi.encode(address(royaltyPolicyLAP), address(disputeModule)) + ) + ) + ); + + ipRoyaltyVaultBeacon = UpgradeableBeacon( + create3Deployer.deploy( + _getSalt("ipRoyaltyVaultBeacon"), + abi.encodePacked( + type(UpgradeableBeacon).creationCode, + abi.encode(address(ipRoyaltyVaultImpl), deployer) + ) + ) + ); + impl = address(new LicenseToken(address(licensingModule), address(disputeModule))); licenseToken = LicenseToken( TestProxyHelper.deployUUPSProxy( @@ -312,6 +353,10 @@ contract BaseTest is Test { coreMetadataViewModule.updateCoreMetadataModule(); licenseRegistry.registerLicenseTemplate(address(pilTemplate)); + + royaltyModule.whitelistRoyaltyPolicy(address(royaltyPolicyLAP), true); + royaltyPolicyLAP.setIpRoyaltyVaultBeacon(address(ipRoyaltyVaultBeacon)); + ipRoyaltyVaultBeacon.transferOwnership(address(royaltyPolicyLAP)); } function setUp_test_Periphery() public { @@ -320,6 +365,8 @@ contract BaseTest is Test { address(accessController), address(ipAssetRegistry), address(licensingModule), + address(licenseRegistry), + address(royaltyModule), address(coreMetadataModule), address(pilTemplate), address(licenseToken) @@ -358,6 +405,7 @@ contract BaseTest is Test { function setUp_test_Misc() public { mockToken = new MockERC20(); + royaltyModule.whitelistRoyaltyToken(address(mockToken), true); vm.label(alice, "Alice"); vm.label(bob, "Bob");