-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathui.py
181 lines (145 loc) · 6.88 KB
/
ui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import base64
import html
from io import BytesIO
from types import SimpleNamespace

import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from helper import get_image_description, generate_image_from_text
# Output styles offered in the "Select output style" dropdown. The selected
# entry is passed to get_image_description and appended to the prompt text.
STYLES = [
    "Photorealistic and Digital art",
    "Oil painting",
    "Watercolor",
    "Pencil sketch",
    "Anime",
    "Comic book",
    "Abstract",
    "Impressionist",
    "Pop art",
]
def skeleton_loader():
return """
<div class="skeleton-loader"></div>
<style>
.skeleton-loader {
width: 512px;
height: 512px;
background: linear-gradient(90deg, #f0f0f0 25%, #e0e0e0 50%, #f0f0f0 75%);
background-size: 200% 100%;
animation: loading 1.5s infinite;
}
@keyframes loading {
0% {
background-position: 200% 0;
}
100% {
background-position: -200% 0;
}
}
</style>
"""
def bordered_placeholder():
return """
<div class="bordered-placeholder"></div>
<style>
.bordered-placeholder {
width: 512px;
height: 512px;
border: 2px dashed #cccccc;
}
</style>
"""
def get_image_download_link(img, filename, text):
    """Return an HTML ``<a>`` tag that downloads *img* as a PNG file.

    The image is encoded to PNG in memory and embedded in the link as a
    base64 ``data:`` URI, so nothing is written to disk.

    Args:
        img: Object with a PIL-style ``save(fp, format=...)`` method
            (a ``PIL.Image.Image`` in this app).
        filename: Suggested name for the downloaded file (``download`` attribute).
        text: Visible link text.

    Returns:
        The anchor tag as a string, intended for
        ``st.markdown(..., unsafe_allow_html=True)``.
    """
    with BytesIO() as buffered:
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("ascii")
    # "image/png" is the registered MIME type (the former "file/png" is not one).
    # filename/text are escaped because the tag is rendered as raw HTML.
    return (
        f'<a href="data:image/png;base64,{img_str}" '
        f'download="{html.escape(filename)}">{html.escape(text)}</a>'
    )
def _to_rgb_image(image_data):
    """Convert stored input image data to an RGB ``PIL.Image``.

    Args:
        image_data: Either a ``PIL.Image.Image`` (from an uploaded file) or
            the array produced by ``st_canvas`` (anything with ``astype``),
            which is interpreted as RGBA.

    Returns:
        The image converted to mode ``'RGB'``.
    """
    if isinstance(image_data, Image.Image):
        img = image_data
    else:
        img = Image.fromarray(image_data.astype('uint8'), 'RGBA')
    return img.convert('RGB')


def _stored_image_data():
    """Return the current input's image data, or ``None`` if there is no input yet."""
    # NOTE(review): st_canvas may return a non-None array even for an untouched
    # canvas; if so, this check does not detect an empty sketch — confirm.
    canvas_result = st.session_state.canvas_result
    if canvas_result is None:
        return None
    return canvas_result.image_data


def _show_generated_image(image_slot, link_slot, image):
    """Render *image* into *image_slot* and its PNG download link into *link_slot*."""
    image_slot.image(image, caption="Generated Realistic Image", width=512)
    link_slot.markdown(
        get_image_download_link(image, "generated_image.png", "Download Generated Image"),
        unsafe_allow_html=True,
    )


def main():
    """Render the sketch-to-image Streamlit app and handle its button actions.

    Layout: three columns.
      * Input: an image upload, or a drawing canvas when nothing is uploaded.
      * Controls: style, extra info, and the action buttons.
      * Output: the generated image, an editable description, and a download link.

    State kept in ``st.session_state`` across reruns:
      * ``canvas_result``: the latest input; an object with ``.image_data``.
      * ``description``: the prompt text.
      * ``realistic_image``: the last generated image.
    """
    st.set_page_config(layout="wide", page_title="Sketch to Realistic Image Converter")
    st.title("Sketch to Realistic Image Converter")

    if 'canvas_result' not in st.session_state:
        st.session_state.canvas_result = None

    col_input, col_controls, col_output = st.columns([2, 1, 2])

    with col_input:
        st.subheader("Draw your sketch or upload an image")
        uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
            # NOTE(review): newer Streamlit releases deprecate use_column_width;
            # confirm the target version before switching to use_container_width.
            st.image(image, caption="Uploaded Image", use_column_width=True)
            # Mimic the ``.image_data`` attribute of an st_canvas result. Unlike a
            # class built with type(), a SimpleNamespace is picklable.
            st.session_state.canvas_result = SimpleNamespace(image_data=image)
        else:
            canvas_result = st_canvas(
                fill_color="rgba(255, 165, 0, 0.3)",
                stroke_width=3,
                stroke_color="#000000",
                background_color="#FFFFFF",
                height=512,
                width=512,
                drawing_mode="freedraw",
                key="canvas",
            )
            if canvas_result.image_data is not None:
                st.session_state.canvas_result = canvas_result

    with col_controls:
        st.subheader("Image Details")
        selected_style = st.selectbox("Select output style", STYLES)
        additional_info = st.text_input("Enter Additional information that specify the actions")
        description_button = st.button(
            "Generate Description" if 'description' not in st.session_state else "Regenerate Description",
            use_container_width=True
        )
        generate_button = st.button("Generate Image", use_container_width=True)
        loading_placeholder = st.empty()

    with col_output:
        st.subheader("Generated Realistic Image")
        output_placeholder = st.empty()
        description_placeholder = st.empty()
        # A single slot for the download link, so a newly generated image
        # replaces the link to the previous one instead of adding a second link.
        download_placeholder = st.empty()
        if 'realistic_image' in st.session_state:
            _show_generated_image(output_placeholder, download_placeholder, st.session_state.realistic_image)
        else:
            output_placeholder.markdown(bordered_placeholder(), unsafe_allow_html=True)

    if description_button:
        image_data = _stored_image_data()
        if image_data is None:
            st.warning("Please draw something on the canvas or upload an image first!")
        else:
            loading_placeholder.text("Generating description...")
            try:
                base_description = get_image_description(
                    _to_rgb_image(image_data), additional_info, selected_style
                )
            except Exception as e:  # UI boundary: report helper failures instead of crashing
                loading_placeholder.empty()
                st.error(f"Error generating description: {e}")
            else:
                st.session_state.description = f"{base_description} \n Style: {selected_style}"
                loading_placeholder.empty()
                st.success("Description generated successfully!")

    if 'description' in st.session_state:
        # The user may edit the text; the edited value becomes the prompt.
        st.session_state.description = description_placeholder.text_area(
            "Image Description (edit if needed)",
            st.session_state.description,
            height=150
        )

    if generate_button:
        image_data = _stored_image_data()
        if image_data is None:
            st.warning("Please draw something on the canvas or upload an image first!")
        elif 'description' not in st.session_state:
            st.warning("Please generate a description first!")
        else:
            loading_placeholder.text("Generating...")
            output_placeholder.markdown(skeleton_loader(), unsafe_allow_html=True)
            error = None
            realistic_image = None
            try:
                realistic_image = generate_image_from_text(
                    st.session_state.description, _to_rgb_image(image_data)
                )
            except Exception as e:  # UI boundary: report helper failures instead of crashing
                error = str(e)
            if error is None and realistic_image is None:
                error = "Image generation returned None"
            loading_placeholder.empty()
            if error is None:
                st.session_state.realistic_image = realistic_image
                _show_generated_image(output_placeholder, download_placeholder, realistic_image)
                st.success("Image generated successfully!")
            else:
                output_placeholder.markdown(bordered_placeholder(), unsafe_allow_html=True)
                download_placeholder.empty()
                st.error(f"Error generating image: {error}")


if __name__ == "__main__":
    main()