-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgenerate_prompts.py
125 lines (108 loc) · 3.05 KB
/
generate_prompts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
This component generates a set of initial prompts that will be used to retrieve images
from the LAION-5B dataset.
"""
import typing as t
import dask.dataframe as dd
import pandas as pd
import pyarrow as pa
from fondant.component import DaskLoadComponent
from fondant.pipeline import lightweight_component
@lightweight_component(produces={"prompt": pa.string()})
class GeneratePromptsComponent(DaskLoadComponent):
    """Load component that emits interior-design text prompts.

    One prompt is built for every combination of ``rooms`` x ``interior_prefix``
    x ``interior_styles`` (via ``make_interior_prompt``), and the prompts are
    returned as a single-partition Dask dataframe with one ``prompt`` column.
    """

    # Interior design styles combined into each prompt. Entries must be unique:
    # a repeated style would produce duplicate prompts (and duplicate retrieval
    # queries) for every room/prefix pair.
    interior_styles = [
        "art deco",
        "bauhaus",
        "bouclé",
        "maximalist",
        "brutalist",
        "coastal",
        "minimalist",
        "rustic",
        "hollywood regency",
        "midcentury modern",
        "modern organic",
        "contemporary",
        "modern",
        "scandinavian",
        "eclectic",
        "bohemian",
        "industrial",
        "traditional",
        "transitional",
        "farmhouse",
        "country",
        "asian",
        "mediterranean",
        "southwestern",
    ]

    # Adjectives placed at the start of each prompt.
    interior_prefix = [
        "comfortable",
        "luxurious",
        "simple",
    ]

    # Room types; lower-cased by make_interior_prompt when the prompt is built.
    rooms = [
        "Bathroom",
        "Living room",
        "Hotel room",
        "Lobby",
        "Entrance hall",
        "Kitchen",
        "Family room",
        "Master bedroom",
        "Bedroom",
        "Kids bedroom",
        "Laundry room",
        "Guest room",
        "Home office",
        "Library room",
        "Playroom",
        "Home Theater room",
        "Gym room",
        "Basement room",
        "Garage",
        "Walk-in closet",
        "Pantry",
        "Gaming room",
        "Attic",
        "Sunroom",
        "Storage room",
        "Study room",
        "Dining room",
        "Loft",
        "Studio room",
        "Apartment",
    ]

    def __init__(self, *, n_rows_to_load: t.Optional[int]) -> None:
        """
        Generate a set of initial prompts that will be used to retrieve images from the
        LAION-5B dataset.

        Args:
            n_rows_to_load: Optional number of prompt rows to load, useful for
                testing pipeline runs on a small scale. A falsy value (``None``
                or ``0``) loads every generated prompt.
        """
        self.n_rows_to_load = n_rows_to_load

    @staticmethod
    def make_interior_prompt(room: str, prefix: str, style: str) -> str:
        """Generate a prompt for the interior design model.

        All three parts are lower-cased, e.g. ("Living room", "comfortable",
        "art deco") -> "comfortable living room, art deco interior design".

        Args:
            room: room name
            prefix: prefix for the room
            style: interior style

        Returns:
            prompt for the interior design model
        """
        return f"{prefix.lower()} {room.lower()}, {style.lower()} interior design"

    def load(self) -> dd.DataFrame:
        """Build every room/prefix/style prompt and return them as a dataframe.

        Returns:
            A Dask dataframe with a single partition and one ``prompt`` column,
            truncated to the first ``n_rows_to_load`` rows when that is truthy.
        """
        # Kept as a function-local import: lightweight components may be run
        # from the class source alone, without this module's imports
        # -- NOTE(review): confirm against the Fondant runtime before hoisting.
        import itertools

        # product() iterates rooms outermost and styles innermost, so prompts
        # for the same room/prefix pair are adjacent in the output.
        combinations = itertools.product(
            self.rooms, self.interior_prefix, self.interior_styles
        )
        prompts = list(itertools.starmap(self.make_interior_prompt, combinations))
        pandas_df = pd.DataFrame(prompts, columns=["prompt"])

        # A falsy value (None or 0) means "keep every prompt".
        if self.n_rows_to_load:
            pandas_df = pandas_df.head(self.n_rows_to_load)

        return dd.from_pandas(pandas_df, npartitions=1)