forked from facebookarchive/iTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgfx.lua
213 lines (198 loc) · 6.52 KB
/
gfx.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
local ffi = require 'ffi'
local uuid = require 'uuid'
local base64 = require 'base64'
require 'pl.text'.format_operator()
require 'image'
local itorch = require 'itorch._env'
require 'itorch.bokeh'
local util = require 'itorch.util'
--- Show an image in the notebook by sending it as an inline PNG in a
-- 'display_data' message on the iopub socket.
-- Example: require 'image';itorch.image(image.scale(image.lena(),16,16))
-- @param img  a tensor or a table (handed to image.toDisplayTensor), or a
--             string, which is assumed to be a file path and is loaded with
--             image.load(img, 3)
-- @param opts optional table of options for image.toDisplayTensor; defaults
--             to {padding=2}. The caller's table is copied, not modified.
-- Raises an error if itorch._iopub / itorch._msg are not set, or if `img`
-- has any other type.
function itorch.image(img, opts)
   assert(itorch._iopub,'itorch._iopub socket not set')
   assert(itorch._msg,'itorch._msg not set')
   -- Use the builtin type(): torch.typename() only names torch classes, so
   -- comparing its result with 'string' never matched and file paths ended up
   -- in the "unhandled type" error below.
   if type(img) == 'string' then -- assume that it is path
      img = image.load(img, 3) -- TODO: revamp this to just directly load the blob, infer file prefix, and send.
   end
   if not (torch.isTensor(img) or torch.type(img) == 'table') then
      error('unhandled type in itorch.image:' .. torch.type(img))
   end

   -- Work on a shallow copy so the `input` field does not leak into the
   -- options table owned by the caller.
   local display_opts = {}
   for k, v in pairs(opts or {padding=2}) do
      display_opts[k] = v
   end
   display_opts.input = img
   local imgDisplay = image.toDisplayTensor(display_opts)
   if imgDisplay:dim() == 2 then
      -- promote a 2D result to 3D so that size(2)/size(3) below are valid
      imgDisplay = imgDisplay:view(1, imgDisplay:size(1), imgDisplay:size(2))
   end

   -- Round-trip through a temporary PNG file to obtain the encoded bytes.
   -- os.tmpname() may create `tmp_base` itself on some platforms, so both the
   -- base name and the '.png' file are removed afterwards.
   local tmp_base = os.tmpname()
   local tmp = tmp_base .. '.png'
   local ok, png = pcall(function()
      image.save(tmp, imgDisplay)
      -- load the image back as binary blob
      local f = assert(torch.DiskFile(tmp,'r',true)):binary()
      f:seekEnd()
      local size = f:position()-1
      f:seek(1)
      local buf = torch.CharStorage(size)
      local nread = f:readChar(buf)
      f:close()
      assert(nread == size, 'wrong number of bytes read')
      return ffi.string(torch.data(buf), size)
   end)
   -- Clean up without going through a shell, and also when saving/reading failed.
   os.remove(tmp)
   os.remove(tmp_base)
   if not ok then
      error(png, 0) -- re-raise the original error unchanged
   end

   local content = {
      source = 'itorch',
      data = {
         ['text/plain'] = 'Console does not support images',
         ['image/png'] = base64.encode(png)
      },
      metadata = {
         ['image/png'] = {width = imgDisplay:size(3), height = imgDisplay:size(2)}
      }
   }
   local m = util.msg('display_data', itorch._msg)
   m.content = content
   util.ipyEncodeAndSend(itorch._iopub, m)
end
-- HTML snippet for an inline audio player; the ${...} placeholders are filled
-- in by itorch.audio through the `%` format operator.
local audio_template = [[
<div id="${div_id}"><audio controls src="data:audio/${extension};base64,${base64audio}" /></div>
]]
--- Embed an audio file in the notebook as a base64 data URI inside an
-- <audio> element, sent as a 'display_data' message.
-- Example: itorch.audio('hello.mp3')
-- @param fname path of the audio file; the text after the last '.' must be
--              one of mp3, wav, ogg, aac
-- @return the id of the <div> that wraps the player
-- Raises an error if itorch._iopub / itorch._msg are not set, if the
-- extension is not supported, or if the file cannot be opened or fully read.
function itorch.audio(fname)
   assert(itorch._iopub,'itorch._iopub socket not set')
   assert(itorch._msg,'itorch._msg not set')
   -- Extension = everything after the last dot. A name without any dot gives
   -- '' so that it fails the check below with a readable message.
   local ext = fname:match('%.([^%.]*)$') or ''
   assert(ext == 'mp3' or ext == 'wav' or ext == 'ogg' or ext == 'aac',
          'mp3, wav, ogg, aac files supported. But found extension: ' .. ext)
   -- load the audio as binary blob
   local f = assert(torch.DiskFile(fname,'r',true),
                    'File could not be opened: ' .. fname):binary()
   f:seekEnd()
   local size = f:position()-1
   f:seek(1)
   local buf = torch.CharStorage(size)
   local nread = f:readChar(buf)
   f:close() -- close before asserting so a short read does not leak the handle
   assert(nread == size, 'wrong number of bytes read')
   local base64audio = base64.encode(ffi.string(torch.data(buf), size))
   local div_id = uuid.new()
   local content = {
      source = 'itorch',
      data = {
         ['text/html'] = audio_template % {
            div_id = div_id,
            extension = ext,
            base64audio = base64audio
         }
      },
      metadata = {}
   }
   local m = util.msg('display_data', itorch._msg)
   m.content = content
   util.ipyEncodeAndSend(itorch._iopub, m)
   -- The original returned `window_id`, a name that is never defined in this
   -- function (so the result was an unrelated global, normally nil).
   return div_id
end
-- HTML snippet for an inline video player; the ${...} placeholders are filled
-- in by itorch.video through the `%` format operator.
local video_template = [[
<div id="${div_id}"><video controls src="data:video/${extension};base64,${base64video}" /></div>
]]
--- Embed a video file in the notebook as a base64 data URI inside a
-- <video> element, sent as a 'display_data' message.
-- Example: itorch.video('hello.mp4')
-- @param fname path of the video file; the text after the last '.' selects
--              the media subtype ('ogv' is mapped to 'ogg')
-- @return the id of the <div> that wraps the player
-- Raises an error if itorch._iopub / itorch._msg are not set, if the
-- extension is not supported, or if the file cannot be opened or fully read.
function itorch.video(fname)
   assert(itorch._iopub,'itorch._iopub socket not set')
   assert(itorch._msg,'itorch._msg not set')
   -- Extension = everything after the last dot. A name without any dot gives
   -- '' so that it fails the check below with a readable message.
   local ext = fname:match('%.([^%.]*)$') or ''
   if ext == 'ogv' then ext = 'ogg' end
   -- NOTE(review): 'wav' is accepted here although the message does not list
   -- it; it looks copied from itorch.audio. Kept so existing callers are not
   -- broken -- confirm whether it should be dropped.
   assert(ext == 'mp4' or ext == 'wav' or ext == 'ogg' or ext == 'webm',
          'mp4, ogg, webm files supported. But found extension: ' .. ext)
   -- load the video as binary blob
   local f = assert(torch.DiskFile(fname,'r',true),
                    'File could not be opened: ' .. fname):binary()
   f:seekEnd()
   local size = f:position()-1
   f:seek(1)
   local buf = torch.CharStorage(size)
   local nread = f:readChar(buf)
   f:close() -- close before asserting so a short read does not leak the handle
   assert(nread == size, 'wrong number of bytes read')
   local base64video = base64.encode(ffi.string(torch.data(buf), size))
   local div_id = uuid.new()
   local content = {
      source = 'itorch',
      data = {
         ['text/html'] = video_template % {
            div_id = div_id,
            extension = ext,
            base64video = base64video
         }
      },
      metadata = {}
   }
   local m = util.msg('display_data', itorch._msg)
   m.content = content
   util.ipyEncodeAndSend(itorch._iopub, m)
   -- The original returned `window_id`, a name that is never defined in this
   -- function (so the result was an unrelated global, normally nil).
   return div_id
end
-- Convenience demo: passes whatever image.lena() returns to itorch.image.
function itorch.lena()
   local render, sample = itorch.image, image.lena
   render(sample())
end
-- Template rendered by itorch.html via the `%` format operator.
-- Placeholders:
--   ${window_id}    id of the element whose content is replaced
--   ${html_content} markup inserted into that element; it is placed inside a
--                   single-quoted JavaScript string, so it must already be
--                   escaped for that context
--   ${div_id}       id given to the empty <div> emitted after the script
-- The script relies on `$` being available in the page (presumably jQuery
-- provided by the notebook front end -- confirm).
local html_template =
[[
<script type="text/javascript">
$(function() {
$("#${window_id}").html('${html_content}'); // clear any previous plot in window_id
});
</script>
<div id="${div_id}"></div>
]]
--- Escape a string so it can be embedded in a quoted JavaScript string
-- literal that itself lives inside an HTML <script> element.
-- @param s string to escape
-- @return the escaped string (a single value)
local function escape_js(s)
   -- Escape everything in ONE pass. The previous chain of gsub calls escaped
   -- quotes first and backslashes afterwards, which doubled the backslash it
   -- had just inserted (' -> \' -> \\') and left the quote unescaped.
   local replacements = {
      ["\\"] = "\\\\", -- backslash
      ["'"]  = "\\'",  -- single quote
      ['"']  = '\\"',  -- double quote
      ["\n"] = "\\n",  -- newline
      ["\r"] = "\\r",  -- carriage return
      ["\t"] = "\\t",  -- tab
      ["\b"] = "\\b",  -- backspace
      ["\f"] = "\\f",  -- form feed
   }
   s = s:gsub("[\\'\"\n\r\t\b\f]", replacements)
   -- A literal "</script>" in the text would close the surrounding <script>
   -- element early; "<\/" is the same text to JavaScript but not to the HTML
   -- parser. Parentheses drop gsub's second result (the match count).
   return (s:gsub("</", "<\\/"))
end
--- Send an HTML fragment to the notebook as a 'display_data' message.
-- The fragment is escaped with escape_js and injected through html_template
-- into the element whose id is `window_id`.
-- @param html      HTML markup to display
-- @param window_id optional id of the target element; when omitted, the id of
--                  the freshly generated <div> is used
-- @return the id that was targeted (window_id, or the generated id)
function itorch.html(html, window_id)
   local escaped = escape_js(html)
   assert(itorch._iopub,'itorch._iopub socket not set')
   assert(itorch._msg,'itorch._msg not set')

   local fresh_id = uuid.new()
   local target_id = window_id or fresh_id

   local rendered = html_template % {
      html_content = escaped,
      window_id = target_id,
      div_id = fresh_id
   }

   -- send displayData
   local msg = util.msg('display_data', itorch._msg)
   msg.content = {
      source = 'itorch',
      data = { ['text/html'] = rendered },
      metadata = {}
   }
   util.ipyEncodeAndSend(itorch._iopub, msg)
   return target_id
end
-- If the optional 'xlua' package can be loaded, replace its progress bar with
-- a throttled plain-text line.
local ok = pcall(function() require 'xlua' end)
if ok then
   -- Wall-clock time of the last printed line. os.time() is used instead of
   -- os.clock(): os.clock() reports CPU time consumed by the program, which
   -- is not the elapsed time that the 15 second throttle below refers to.
   local last = os.time()
   xlua.progress = function(i, n)
      -- make progress bar really slow in itorch
      if os.difftime(os.time(), last) > 15 then -- 15 seconds
         print('Progress: ' .. i .. ' / ' .. n)
         last = os.time()
      end
      -- local m = util.msg('clear_output', itorch._msg)
      -- m.content = {}
      -- m.content.wait = true
      -- m.content.display = true
      -- util.ipyEncodeAndSend(itorch._iopub, m)
      -- itorch.html(progress_template % {})
   end
end
return itorch