Skip to content

Commit

Permalink
✨ feat(random_prompt): Update hair color and style probabilities
Browse files Browse the repository at this point in the history
- Adjusted the probabilities for hair color and style choices in the
  random_prompt generator.
- Decreased the probability of selecting hair color to 0.1.
- Decreased the probability of selecting back hair style to 0.1.
- Increased the probability of selecting hair features to 0.12.
- Updated the probabilities for NSFW tags related to gender and body parts.

See the commit details for more information.
  • Loading branch information
sudoskys committed Feb 7, 2024
1 parent bc53c26 commit b3e31d2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
14 changes: 10 additions & 4 deletions playground/random_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
# @Author : sudoskys
# @File : random_prompt.py
# @Software: PyCharm
import random

from novelai_python.utils.random_prompt import RandomPromptGenerator

print(random.random())
s = RandomPromptGenerator(nsfw_enabled=True).generate()
print(s)
gen = RandomPromptGenerator(nsfw_enabled=True)
print(gen.get_weighted_choice([[1, 35], [2, 20], [3, 7]], []))
print("====")
print(gen.get_weighted_choice([['mss', 30], ['fdd', 50], ['oa', 10]], []))
print("====")
print(gen.get_weighted_choice([['m', 30], ['f', 50], ['o', 10]], ['m']))
print("====")
for i in range(200):
s = RandomPromptGenerator(nsfw_enabled=True).generate()
print(s)
15 changes: 12 additions & 3 deletions src/novelai_python/utils/random_prompt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,25 @@ def get_weighted_choice(tags, existing_tags: List[str]):
:param existing_tags: a list of existing tags
:return: a tag
"""
valid_tags = [tag for tag in tags if
len(tag) < 3 or not tag[2] or any(sub_tag in existing_tags for sub_tag in tag[2])]
valid_tags = [tag
for tag in tags
if len(tag) < 3 or not tag[2] or any(sub_tag in existing_tags for sub_tag in tag[2])]
total_weight = sum(tagr[1] for tagr in valid_tags if len(tagr) > 1)
if total_weight == 0:
return random.choice(tags)
if isinstance(tags, list):
rd = random.choice(tags)
elif isinstance(tags, str):
rd = tags
else:
raise ValueError('get_weighted_choice: should not reach here')
return rd
random_number = random.randint(1, total_weight)
cumulative_weight = 0
for tag in valid_tags:
cumulative_weight += tag[1]
if random_number <= cumulative_weight:
if isinstance(tag, str):
raise Exception("tag is string")
return tag[0]
raise ValueError('get_weighted_choice: should not reach here')

Expand Down

0 comments on commit b3e31d2

Please sign in to comment.