diff --git a/deploy/configs/inference_ram.yaml b/deploy/configs/inference_ram.yaml new file mode 100644 index 0000000000..0c513047f5 --- /dev/null +++ b/deploy/configs/inference_ram.yaml @@ -0,0 +1,36 @@ +Global: + infer_imgs: "docs/images/inference_deployment/whl_demo.jpg" + inference_model_dir: "./inference" + batch_size: 1 + use_gpu: False + enable_mkldnn: False + cpu_num_threads: 10 + enable_benchmark: True + use_fp16: False + ir_optim: False # do not set it as True since there is a bug which leads the invaild initilize for predictor + use_tensorrt: False + gpu_mem: 8000 + enable_profile: False + +PreProcess: + transform_ops: + - ResizeImage: + resize_short: 384 + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: "" + channel_num: 3 + - ToCHWImage: +PostProcess: + main_indicator: RamOutPut + RamOutPut: + language: "en" + tag_list: "ppcls/utils/ram/ram_tag_list.txt" + tag_list_chinese: "ppcls/utils/ram/ram_tag_list_chinese.txt" + ram_class_threshold_path: "ppcls/utils/ram/ram_tag_list_threshold.txt" + + diff --git a/deploy/python/postprocess.py b/deploy/python/postprocess.py index e6709dcef9..ca3d1facca 100644 --- a/deploy/python/postprocess.py +++ b/deploy/python/postprocess.py @@ -71,7 +71,7 @@ def __init__(self, func_list, main_indicator="Topk"): def __call__(self, x, image_file=None): rtn = None for func in self.func_list: - tmp = func(x, image_file) + tmp = func(*x, image_file) if type(func).__name__ in self.main_indicator: rtn = tmp return rtn @@ -496,3 +496,74 @@ def __call__(self, batch_preds, file_names=None): ).astype(np.int8).tolist() batch_res.append({"attributes": label_res, "output": pred_res}) return batch_res + + +class RamOutPut(object): + def __init__(self, + language="cn", + tag_list="", + tag_list_chinese="", + threshold=0.68, + delete_tag_index=[], + ram_class_threshold_path=""): + self.language = language + assert tag_list, tag_list_chinese + self.tag_list = self.load_tag_list(tag_list) + self.delete_tag_index = delete_tag_index + self.tag_list_chinese = self.load_tag_list(tag_list_chinese) + self.num_class = len(self.tag_list) + self.class_threshold = paddle.ones([self.num_class]) * threshold + with open(ram_class_threshold_path, "r", encoding="utf-8") as f: + ram_class_threshold = [float(s.strip()) for s in f] + for key, value in enumerate(ram_class_threshold): + self.class_threshold[key] = value + + def load_tag_list(self, tag_list_file): + with open(tag_list_file, "r", encoding="utf-8") as f: + tag_list = f.read().splitlines() + tag_list = np.array(tag_list) + return tag_list + + def __call__(self, logits, bs, file_names=None): + batch_res = [] + if bs is None: + if len(logits.shape) < 2: + bs = 1 + else: + bs = logits.shape[0] + logits = paddle.to_tensor(logits).reshape([bs,-1]) + targets = paddle.where( + F.sigmoid(logits) > self.class_threshold, + paddle.to_tensor([1.0]), paddle.zeros(self.num_class)) + targets = targets.reshape([bs, -1]) + res = {} + tag = targets.cpu().numpy() + tag[:, self.delete_tag_index] = 0 + tag_output = [] + tag_output_chinese = [] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + tag_output.append(" | ".join(token)) + token_chinese = self.tag_list_chinese[index].squeeze(axis=1) + tag_output_chinese.append(" | ".join(token_chinese)) + res["cn"] = tag_output_chinese + res["en"] = tag_output + res["all"] = f"en : {tag_output}, cn: {tag_output_chinese}" + + scores = F.sigmoid(logits).numpy() + class_ids_list = [] + scores_list = [] + + for b in range(bs): + index = np.argwhere(tag[b] == 1) + class_ids_list.append(index.tolist()) + scores_list.append(scores[b][index].tolist()) + + outputformat = { + "class_ids": class_ids_list, + "scores": scores_list, + "label_names": res[self.language] + } + batch_res.append(outputformat) + return outputformat diff --git a/deploy/python/predict_multimodal.py b/deploy/python/predict_multimodal.py new file mode 100644 index 0000000000..6397e734b8 --- /dev/null +++ b/deploy/python/predict_multimodal.py @@ -0,0 +1,68 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import cv2 +import numpy as np + +from paddleclas.deploy.utils import logger, config +from paddleclas.deploy.utils.get_image_list import get_image_list +from paddleclas.deploy.python.predict_cls import ClsPredictor + + +def main(config): + cls_predictor = ClsPredictor(config) + image_list = get_image_list(config["Global"]["infer_imgs"]) + + batch_imgs = [] + batch_names = [] + cnt = 0 + for idx, img_path in enumerate(image_list): + img = cv2.imread(img_path) + if img is None: + logger.warning( + "Image file failed to read and has been skipped. The path: {}". + format(img_path)) + else: + img = img[:, :, ::-1] + print(img.shape) + batch_imgs.append(img) + img_name = os.path.basename(img_path) + batch_names.append(img_name) + cnt += 1 + + if cnt % config["Global"]["batch_size"] == 0 or (idx + 1 + ) == len(image_list): + if len(batch_imgs) == 0: + continue + batch_results = cls_predictor.predict(batch_imgs) + for number, result_dict in enumerate(batch_results): + if len(batch_imgs) == 0: + continue + for number, result_key in enumerate(batch_results.keys()): + print( + f"{img_name}-{result_key}:{batch_results[result_key]}") + batch_imgs = [] + batch_names = [] + if cls_predictor.benchmark: + cls_predictor.auto_logger.report() + return + + + +if __name__ == "__main__": + args = config.parse_args() + config = config.get_config(args.config, overrides=args.override, show=True) + main(config) diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index 7800da33ba..c1151732fb 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -20,6 +20,7 @@ from . import backbone, gears from .backbone import * +from .ram import * from .gears import build_gear, add_ml_decoder_head from .utils import * from .backbone.base.theseus_layer import TheseusLayer diff --git a/ppcls/arch/backbone/clip/__init__.py b/ppcls/arch/backbone/clip/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ppcls/arch/clip/__init__.py b/ppcls/arch/clip/__init__.py new file mode 100644 index 0000000000..014e71984b --- /dev/null +++ b/ppcls/arch/clip/__init__.py @@ -0,0 +1,2 @@ +from .clip import CLIP_vit_base_patch16_224_with_TextEncoder, CLIP_vit_base_patch32_224_with_TextEncoder, CLIP_vit_large_patch14_224_with_TextEncoder, CLIP_vit_large_patch16_224_with_TextEncoder +from .tokenizer import Tokenizer \ No newline at end of file diff --git a/ppcls/arch/clip/clip.py b/ppcls/arch/clip/clip.py new file mode 100644 index 0000000000..fa2352aae4 --- /dev/null +++ b/ppcls/arch/clip/clip.py @@ -0,0 +1,596 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Code was based on https://github.com/AgentMaker/Paddle-CLIP, https://github.com/openai/CLIP/ +# reference: https://arxiv.org/abs/2103.00020 + +import math + +import paddle +import paddle.nn as nn +from paddle.nn import functional as F +from paddle.nn.initializer import Assign, Normal, Constant + + +class Identity(nn.Layer): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, inputs): + return inputs + + +class QuickGELU(nn.Layer): + def forward(self, x): + return x * nn.functional.sigmoid(1.702 * x) + + + + +class Attention(nn.Layer): + def __init__( + self, + embed_dim, + num_heads=8, + output_dim=None, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert embed_dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias_attr=qkv_bias) + + if self.scaled_cosine: + self.logit_scale = paddle.log(10 * paddle.ones((num_heads, 1, 1))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = paddle.ones((num_heads, 1, 1)) + else: + self.head_scale = None + self.out_proj = nn.Linear(embed_dim, output_dim) if output_dim else nn.Linear(embed_dim, embed_dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask=None): + L, N, C = paddle.shape(x) + q, k, v = self.in_proj(x).chunk(3, axis=-1) + + q = q.reshape([L, N * self.num_heads, -1]).transpose([1, 0, 2]) * self.scale + k = k.reshape([L, N * self.num_heads, -1]).transpose([1, 0, 2]) * self.scale + v = v.reshape([L, N * self.num_heads, -1]).transpose([1, 0, 2]) * self.scale + + if self.logit_scale is not None: + attn = paddle.bmm(F.normalize(q, dim=-1), F.normalize(k, axis=-1).transpose([0, 2, 1])) + logit_scale = paddle.clip(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.reshape([N, self.num_heads, L, L]) * logit_scale + else: + q = q * self.scale + attn = paddle.bmm(q, k.transpose([0, 2, 1])) + + if attn_mask is not None: + if attn_mask.dtype == paddle.bool: + new_attn_mask = paddle.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = F.softmax(attn, axis=-1) + attn = self.attn_drop(attn) + attn = attn.reshape([-1, L, L]) + x = paddle.bmm(attn, v) + if self.head_scale is not None: + x = x.reshape([N, self.num_heads, L, C]) * self.head_scale + x = x.reshape([-1, L, C]) + x = x.transpose([1, 0, 2]).reshape([L, N, C]) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class Bottleneck(nn.Layer): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + self.conv1 = nn.Conv2D(inplanes, planes, 1, bias_attr=False) + self.bn1 = nn.BatchNorm2D(planes) + + self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False) + self.bn2 = nn.BatchNorm2D(planes) + + self.avgpool = nn.AvgPool2D(stride) if stride > 1 else Identity() + + self.conv3 = nn.Conv2D( + planes, planes * self.expansion, 1, bias_attr=False) + self.bn3 = nn.BatchNorm2D(planes * self.expansion) + + self.relu = nn.ReLU() + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + self.downsample = nn.Sequential( + ("-1", nn.AvgPool2D(stride)), ("0", nn.Conv2D( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias_attr=False)), + ("1", nn.BatchNorm2D(planes * self.expansion))) + + def forward(self, x): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2D(nn.Layer): + def __init__(self, spacial_dim, embed_dim, num_heads, output_dim=None): + super().__init__() + positional_embedding = self.create_parameter( + shape=(spacial_dim**2 + 1, embed_dim), + default_initializer=Assign( + paddle.randn((spacial_dim**2 + 1, embed_dim)) / embed_dim + **0.5)) + self.add_parameter("positional_embedding", positional_embedding) + + self.attn = Attention(embed_dim, num_heads, output_dim) + + def forward(self, x): + x = x.reshape((x.shape[0], x.shape[1], + x.shape[2] * x.shape[3])).transpose((2, 0, 1)) + x = paddle.concat([x.mean(axis=0, keepdim=True), x], axis=0) + x = x + self.positional_embedding.unsqueeze(1) + x = x.transpose((1, 0, 2)) + x = self.attn(query=x, key=x, value=x) + x = x.transpose((1, 0, 2)) + return x[0] + + +class ModifiedResNet(nn.Layer): + def __init__(self, + layers, + output_dim, + heads, + input_resolution=224, + width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + self.conv1 = nn.Conv2D( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias_attr=False) + self.bn1 = nn.BatchNorm2D(width // 2) + + self.conv2 = nn.Conv2D( + width // 2, width // 2, kernel_size=3, padding=1, bias_attr=False) + self.bn2 = nn.BatchNorm2D(width // 2) + + self.conv3 = nn.Conv2D( + width // 2, width, kernel_size=3, padding=1, bias_attr=False) + self.bn3 = nn.BatchNorm2D(width) + + self.avgpool = nn.AvgPool2D(2) + self.relu = nn.ReLU() + + # residual layers + self._inplanes = width + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 + self.attnpool = AttentionPool2D(input_resolution // 32, embed_dim, + heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def stem(self, x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + return x + + +class ResidualAttentionBlock(nn.Layer): + def __init__(self, d_model, n_head, attn_mask=None): + super().__init__() + self.attn = Attention(d_model, n_head) + self.ln_1 = nn.LayerNorm(d_model) + self.mlp = nn.Sequential(("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model))) + self.ln_2 = nn.LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x): + return self.attn(x, attn_mask=self.attn_mask) + + def forward(self, x): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Layer): + def __init__(self, width, layers, heads, attn_mask=None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) + + def forward(self, x): + return self.resblocks(x) + + +class VisualTransformer(nn.Layer): + def __init__(self, input_resolution, patch_size, width, layers, heads, + output_dim): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2D( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias_attr=False) + + scale = width**-0.5 + + class_embedding = self.create_parameter( + shape=(width, ), + default_initializer=Assign(scale * paddle.randn((width, )))) + self.add_parameter("class_embedding", class_embedding) + + positional_embedding = self.create_parameter( + shape=(width, ), + default_initializer=Assign(scale * paddle.randn(( + (input_resolution // patch_size)**2 + 1, width)))) + self.add_parameter("positional_embedding", positional_embedding) + + self.ln_pre = nn.LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = nn.LayerNorm(width) + + proj = self.create_parameter( + shape=(width, ), + default_initializer=Assign(scale * paddle.randn(( + (width, output_dim))))) + self.add_parameter("proj", proj) + + def forward(self, x): + x = self.conv1(x) + x = x.reshape((x.shape[0], x.shape[1], -1)) + x = x.transpose((0, 2, 1)) + zeros = paddle.zeros((x.shape[0], 1, x.shape[-1]), dtype='float32') + x = paddle.concat([self.class_embedding + zeros, x], axis=1) + x = x + self.positional_embedding + x = self.ln_pre(x) + x = self.transformer(x) + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Layer): + def __init__( + self, + embed_dim, + # vision + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + # text + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers): + super().__init__() + self.context_length = context_length + self.embed_dim = embed_dim + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width) + else: + vision_heads = vision_width // 64 + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + + positional_embedding = self.create_parameter( + shape=(self.context_length, transformer_width), + default_initializer=Assign( + paddle.empty((self.context_length, transformer_width)))) + self.add_parameter("positional_embedding", positional_embedding) + + self.ln_final = nn.LayerNorm(transformer_width) + + text_projection = self.create_parameter( + shape=(transformer_width, embed_dim), + default_initializer=Assign( + paddle.empty((transformer_width, embed_dim)))) + self.add_parameter("text_projection", text_projection) + + logit_scale = self.create_parameter( + shape=(1, ), default_initializer=Assign(paddle.ones([1]))) + self.add_parameter("logit_scale", logit_scale) + + self.initialize_parameters() + + def initialize_parameters(self): + Normal(std=0.02)(self.token_embedding.weight) + Normal(std=0.01)(self.positional_embedding) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.embed_dim**-0.5 + normal_ = Normal(std=std) + normal_(self.visual.attnpool.attn.q_proj.weight) + normal_(self.visual.attnpool.attn.k_proj.weight) + normal_(self.visual.attnpool.attn.v_proj.weight) + normal_(self.visual.attnpool.attn.out_proj.weight) + + for resnet_block in [ + self.visual.layer1, self.visual.layer2, self.visual.layer3, + self.visual.layer4 + ]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + Constant(value=0.0)(param) + + proj_std = (self.transformer.width ** -0.5) * \ + ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + + for resblock in self.transformer.resblocks: + normal_ = Normal(std=attn_std) + normal_(resblock.attn.in_proj.weight) + Normal(std=proj_std)(resblock.attn.out_proj.weight) + Normal(std=fc_std)(resblock.mlp.c_fc.weight) + Normal(std=proj_std)(resblock.mlp.c_proj.weight) + + if self.text_projection is not None: + Normal(std=self.transformer.width**-0.5)(self.text_projection) + + def build_attention_mask(self): + mask = paddle.full((self.context_length, self.context_length), + float("-inf")) + mask = paddle.triu(mask, diagonal=1) + return mask + + def text_global_pool(self, x, text=None, pool_type='first'): + if pool_type == 'first': + pooled, tokens = x[:, 0], x[:, 1:] + elif pool_type == 'last': + pooled, tokens = x[:, -1], x[:, :-1] + elif pool_type == 'argmax': + # take features from the eot embedding (eot_token is the highest number in each sequence) + assert text is not None + index = paddle.to_tensor( [paddle.arange(x.shape[0]), text.argmax(axis=-1)]) + pooled, tokens = paddle.index_select(x, index), x + else: + pooled = tokens = x + + return pooled, tokens + + def encode_image(self, image): + return self.visual(image) + + def encode_text(self, text): + x = self.token_embedding(text) + x = x + self.positional_embedding + x = x.transpose([1, 0, 2]) + x = self.transformer(x) + x = x.transpose([1, 0, 2]) + x = self.ln_final(x) + + x, _ = self.text_global_pool(x , text) + if self.text_projection is not None: + if isinstance(self.text_projection, paddle.nn.Linear): + x = self.text_projection(x) + else: + x = x @ self.text_projection + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / \ + image_features.norm(axis=-1, keepdim=True) + text_features = text_features / \ + text_features.norm(axis=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @text_features.t() + logits_per_text = logit_scale * text_features @image_features.t() + + # unify the format for paddle loss + results = {"image": logits_per_image, "text":logits_per_text} + return results + + + +def tokenize(texts, tokenizer, context_length=77): + """ + Returns the tokenized representation of given input string(s) + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = tokenizer.encoder["<|startoftext|>"] + eot_token = tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] + for text in texts] + result = paddle.zeros((len(all_tokens), context_length), dtype='int64') + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + raise RuntimeError( + f"Input {texts[i]} is too long for context length {context_length}" + ) + result[i, :len(tokens)] = paddle.to_tensor(tokens) + + return result + + +def CLIP_vit_base_patch32_224_with_TextEncoder(**kwargs): + model = CLIP( + embed_dim=512, + image_resolution=224, + vision_layers=12, + vision_width=768, + vision_patch_size=32, + context_length=77, + vocab_size=49408, + transformer_width=512, + transformer_heads=8, + transformer_layers=12) + return model + + +def CLIP_vit_base_patch16_224_with_TextEncoder(**kwargs): + model = CLIP( + embed_dim=512, + image_resolution=224, + vision_layers=12, + vision_width=768, + vision_patch_size=16, + context_length=77, + vocab_size=49408, + transformer_width=512, + transformer_heads=8, + transformer_layers=12) + return model + + +def CLIP_vit_large_patch14_224_with_TextEncoder(**kwargs): + model = CLIP( + embed_dim=768, + image_resolution=224, + vision_layers=24, + vision_width=1024, + vision_patch_size=14, + context_length=77, + vocab_size=49408, + transformer_width=768, + transformer_heads=12, + transformer_layers=12) + return model + + +def CLIP_vit_large_patch16_224_with_TextEncoder(**kwargs): + model = CLIP( + embed_dim=768, + image_resolution=224, + vision_layers=24, + vision_width=1024, + vision_patch_size=16, + context_length=77, + vocab_size=49408, + transformer_width=768, + transformer_heads=12, + transformer_layers=12) + return model + + +CLIP_DICT = { + "vit-b-32-224": CLIP_vit_base_patch32_224_with_TextEncoder(), + "vit-b-16-224": CLIP_vit_base_patch16_224_with_TextEncoder(), + "vit-l-16-224": CLIP_vit_base_patch16_224_with_TextEncoder(), + "vit-l-14-224": CLIP_vit_large_patch14_224_with_TextEncoder(), +} diff --git a/ppcls/arch/clip/tokenizer.py b/ppcls/arch/clip/tokenizer.py new file mode 100644 index 0000000000..084d8148c0 --- /dev/null +++ b/ppcls/arch/clip/tokenizer.py @@ -0,0 +1,160 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Code was based on https://github.com/AgentMaker/Paddle-CLIP, https://github.com/openai/CLIP/ +# reference: https://arxiv.org/abs/2103.00020 + +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return "ppcls/utils/ram/bpe_simple_vocab_16e6.txt.gz" + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1)) + \ + list(range(ord("¡"), ord("¬")+1)) + \ + list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class Tokenizer(object): + def __init__(self, bpe_path: str=default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + '<|startoftext|>': '<|startoftext|>', + '<|endoftext|>': '<|endoftext|>' + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors="replace").replace('', ' ') + return text diff --git a/ppcls/arch/ram/__init__.py b/ppcls/arch/ram/__init__.py new file mode 100644 index 0000000000..c2bff3a624 --- /dev/null +++ b/ppcls/arch/ram/__init__.py @@ -0,0 +1,2 @@ +from .ram import ram +from .ram_plus import ram_plus \ No newline at end of file diff --git a/ppcls/arch/ram/bert.py b/ppcls/arch/ram/bert.py new file mode 100644 index 0000000000..33f882767f --- /dev/null +++ b/ppcls/arch/ram/bert.py @@ -0,0 +1,1397 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Code was based on https://github.com/PaddlePaddle/PaddleNLP/ + +import warnings +from typing import Optional, Tuple +import math +from dataclasses import dataclass + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import Tensor +from paddle.nn import Layer + +try: + from paddle.incubate.nn import FusedTransformerEncoderLayer +except ImportError: + FusedTransformerEncoderLayer = None +from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model, prune_linear_layer, find_pruneable_heads_and_indices, apply_chunking_to_forward +from paddlenlp.layers import Linear as TransposedLinear +from paddlenlp.utils.converter import StateDictNameMapping, init_name_mappings +from paddlenlp.utils.env import CONFIG_NAME +from paddlenlp.transformers.activations import ACT2FN +from paddlenlp.transformers.model_outputs import ( + BaseModelOutputWithPoolingAndCrossAttentions, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput,) +from paddlenlp.transformers.bert.configuration import ( + BERT_PRETRAINED_INIT_CONFIGURATION, + BERT_PRETRAINED_RESOURCE_FILES_MAP, + BertConfig, ) + +__all__ = [ + "BertModel", + "BertPretrainedModel", + "BertForPretraining", + "BertPretrainingCriterion", + "BertPretrainingHeads", + "BertLMHeadModel", + "BertForMaskedLM", +] + + +class BertSelfAttention(nn.Layer): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.shape[:-1] + [ + self.num_attention_heads, self.attention_head_size + ] + x = x.reshape(new_x_shape) + return x.transpose([0, 2, 1, 3]) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + # print(self.key.weight.shape) + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = paddle.concat([past_key_value[0], key_layer], axis=2) + value_layer = paddle.concat( + [past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # compatible with higher versions of transformers + if key_layer.shape[0] > query_layer.shape[0]: + key_layer = key_layer[:query_layer.shape[0], :, :, :] + attention_mask = attention_mask[:query_layer.shape[0], :, :] + value_layer = value_layer[:query_layer.shape[0], :, :, :] + + # Take the dot product between "query" and "key" to get the raw attention scores. + key = key_layer.transpose([0, 1, 3, 2]) + attention_scores = paddle.matmul(query_layer, key) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.shape[1] + position_ids_l = paddle.arange( + seq_length, dtype=paddle.int32).reshape([-1, 1]) + position_ids_r = paddle.arange( + seq_length, dtype=paddle.int32).reshape([-1, 1]) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = paddle.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = paddle.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = paddle.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(axis=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = paddle.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.transpose([0, 2, 1, 3]) + new_context_layer_shape = context_layer.shape[:-2] + [ + self.all_head_size, + ] + context_layer = context_layer.reshape(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else ( + context_layer, ) + + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Layer): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, epsilon=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Layer): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Layer): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Layer): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, epsilon=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Layer): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, ): + + if mode == 'tagging': + + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, ) + attention_output = cross_attention_outputs[0] + outputs = cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + + present_key_value = cross_attention_outputs[-1] + + else: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode == 'multimodal': + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1: + -1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, + self.seq_len_dim, attention_output) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.LayerList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = paddle.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1], + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, ) + + +class BertEmbeddings(Layer): + """ + Include embeddings from word, position and token_type embeddings + """ + + def __init__(self, config: BertConfig): + super(BertEmbeddings, self).__init__() + + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.register_buffer( + "position_ids", + paddle.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute") + + def forward( + self, + input_ids: Tensor, + token_type_ids: Optional[Tensor]=None, + position_ids: Optional[Tensor]=None, + inputs_embeds=None, + past_key_values_length: Optional[int]=None, ): + if input_ids is not None: + input_shape = input_ids.shape + else: + input_shape = inputs_embeds.shape[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: + seq_length + + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if token_type_ids is None: + token_type_ids = paddle.zeros_like(input_ids, dtype="int64") + + input_embedings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + embeddings = input_embedings + position_embeddings + embeddings = self.layer_norm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertPooler(Layer): + """ + Pool the result of BertEncoder. + """ + + def __init__(self, config: BertConfig): + """init the bert pooler with config & args/kwargs + + Args: + config (BertConfig): BertConfig instance. Defaults to None. + """ + super(BertPooler, self).__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + self.pool_act = config.pool_act + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + if self.pool_act == "tanh": + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPretrainedModel(PretrainedModel): + """ + An abstract class for pretrained BERT models. It provides BERT related + `model_config_file`, `resource_files_names`, `pretrained_resource_files_map`, + `pretrained_init_configuration`, `base_model_prefix` for downloading and + loading pretrained models. + See :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details. + """ + + model_config_file = CONFIG_NAME + config_class = BertConfig + resource_files_names = {"model_state": "model_state.pdparams"} + base_model_prefix = "bert" + + pretrained_init_configuration = BERT_PRETRAINED_INIT_CONFIGURATION + pretrained_resource_files_map = BERT_PRETRAINED_RESOURCE_FILES_MAP + + @classmethod + def _get_name_mappings(cls, + config: BertConfig) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + "embeddings.word_embeddings.weight", + "embeddings.position_embeddings.weight", + "embeddings.token_type_embeddings.weight", + ["embeddings.LayerNorm.weight", "embeddings.layer_norm.weight"], + ["embeddings.LayerNorm.bias", "embeddings.layer_norm.bias"], + ["pooler.dense.weight", None, "transpose"], + "pooler.dense.bias", + # for TokenClassification + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [ + f"encoder.layer.{layer_index}.attention.self.query.weight", + f"encoder.layers.{layer_index}.self_attn.q_proj.weight", + "transpose", + ], + [ + f"encoder.layer.{layer_index}.attention.self.query.bias", + f"encoder.layers.{layer_index}.self_attn.q_proj.bias", + ], + [ + f"encoder.layer.{layer_index}.attention.self.key.weight", + f"encoder.layers.{layer_index}.self_attn.k_proj.weight", + "transpose", + ], + [ + f"encoder.layer.{layer_index}.attention.self.key.bias", + f"encoder.layers.{layer_index}.self_attn.k_proj.bias", + ], + [ + f"encoder.layer.{layer_index}.attention.self.value.weight", + f"encoder.layers.{layer_index}.self_attn.v_proj.weight", + "transpose", + ], + [ + f"encoder.layer.{layer_index}.attention.self.value.bias", + f"encoder.layers.{layer_index}.self_attn.v_proj.bias", + ], + [ + f"encoder.layer.{layer_index}.attention.output.dense.weight", + f"encoder.layers.{layer_index}.self_attn.out_proj.weight", + "transpose", + ], + [ + f"encoder.layer.{layer_index}.attention.output.dense.bias", + f"encoder.layers.{layer_index}.self_attn.out_proj.bias", + ], + [ + f"encoder.layer.{layer_index}.intermediate.dense.weight", + f"encoder.layers.{layer_index}.linear1.weight", + "transpose", + ], + [ + f"encoder.layer.{layer_index}.intermediate.dense.bias", + f"encoder.layers.{layer_index}.linear1.bias" + ], + [ + f"encoder.layer.{layer_index}.attention.output.LayerNorm.weight", + f"encoder.layers.{layer_index}.norm1.weight", + ], + [ + f"encoder.layer.{layer_index}.attention.output.LayerNorm.bias", + f"encoder.layers.{layer_index}.norm1.bias", + ], + [ + f"encoder.layer.{layer_index}.output.dense.weight", + f"encoder.layers.{layer_index}.linear2.weight", + "transpose", + ], + [ + f"encoder.layer.{layer_index}.output.dense.bias", + f"encoder.layers.{layer_index}.linear2.bias" + ], + [ + f"encoder.layer.{layer_index}.output.LayerNorm.weight", + f"encoder.layers.{layer_index}.norm2.weight" + ], + [ + f"encoder.layer.{layer_index}.output.LayerNorm.bias", + f"encoder.layers.{layer_index}.norm2.bias" + ], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(model_mappings) + + # base-model prefix "BertModel" + if "BertModel" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "bert." + mapping[0] + mapping[1] = "bert." + mapping[1] + + # downstream mappings + if "BertForQuestionAnswering" in config.architectures: + model_mappings.extend( + [["qa_outputs.weight", "classifier.weight", "transpose"], + ["qa_outputs.bias", "classifier.bias"]]) + if ("BertForMultipleChoice" in config.architectures or + "BertForSequenceClassification" in config.architectures or + "BertForTokenClassification" in config.architectures): + model_mappings.extend( + [["classifier.weight", "classifier.weight", "transpose"]]) + + mappings = [ + StateDictNameMapping( + *mapping, index=index) + for index, mapping in enumerate(model_mappings) + ] + return mappings + + def _init_weights(self, layer): + """Initialization hook""" + if isinstance(layer, (nn.Linear, nn.Embedding)): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + if isinstance(layer.weight, paddle.Tensor): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range, + shape=layer.weight.shape, )) + + elif isinstance(layer, nn.LayerNorm): + layer._epsilon = self.config.layer_norm_eps + + +@register_base_model +class BertModel(BertPretrainedModel): + """ + The bare BERT Model transformer outputting raw hidden-states. + + This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`. + Refer to the superclass documentation for the generic methods. + + This model is also a Paddle `paddle.nn.Layer `__ subclass. Use it as a regular Paddle Layer + and refer to the Paddle documentation for all matter related to general usage and behavior. + + Args: + config (:class:`BertConfig`): + An instance of BertConfig used to construct BertModel. + """ + + def __init__(self, config: BertConfig, add_pooling_layer=True): + super(BertModel, self).__init__(config) + + self.pad_token_id = config.pad_token_id + self.initializer_range = config.initializer_range + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + if config.fuse and FusedTransformerEncoderLayer is None: + warnings.warn( + "FusedTransformerEncoderLayer is not supported by the running Paddle. " + "The flag fuse_transformer will be ignored. Try Paddle >= 2.3.0" + ) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def get_extended_attention_mask(self, + attention_mask: Tensor, + input_shape: Tuple[int], + is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`paddle.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`paddle.device`): + The device of the input to the model. + + Returns: + :obj:`paddle.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = paddle.arange(seq_length) + causal_mask = seq_ids[None, None, :].tile( + [batch_size, seq_length, 1]) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.astype(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + causal_mask = paddle.concat( + [ + paddle.ones( + (batch_size, seq_length, prefix_seq_len), + dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, ) + + extended_attention_mask = causal_mask[:, + None, :, :] * attention_mask[:, + None, + None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})". + format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.astype( + dtype=paddle.float16) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def get_head_mask(self, + head_mask: Optional[Tensor], + num_hidden_layers: int, + is_attention_chunked: bool=False) -> Tensor: + """ + Prepare the head mask if needed. + + Args: + head_mask (`paddle.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). + num_hidden_layers (`int`): + The number of hidden layers in the model. + is_attention_chunked: (`bool`, *optional*, defaults to `False`): + Whether or not the attentions scores are computed by chunks or not. + + Returns: + `paddle.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with + `[None]` for each layer. + """ + if head_mask is not None: + head_mask = self._convert_head_mask_to_5d(head_mask, + num_hidden_layers) + if is_attention_chunked is True: + head_mask = head_mask.unsqueeze(-1) + else: + head_mask = [None] * num_hidden_layers + + return head_mask + + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" + if len(head_mask.shape) == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze( + -1).unsqueeze(-1) + head_mask = head_mask.expand([num_hidden_layers, -1, -1, -1, -1]) + elif len(head_mask.shape) == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze( + -1) # We can specify head_mask for each layer + assert len( + head_mask. + shape) == 5, f"head_mask.dim != 5, instead {len(head_mask.shape)}" + head_mask = head_mask.astype( + dtype=self.dtype) # switch to float if need + fp16 compatibility + return head_mask + + def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`paddle.Tensor`): An attention mask. + + Returns: + `paddle.Tensor`: The inverted attention mask. + """ + if len(encoder_attention_mask.shape) == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, + None, :, :] + if len(encoder_attention_mask.shape) == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, + None, :] + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + + encoder_extended_attention_mask = ( + 1.0 - encoder_extended_attention_mask + ) * paddle.finfo(paddle.float16).min + + return encoder_extended_attention_mask + + def forward( + self, + input_ids: Tensor=None, + token_type_ids: Optional[Tensor]=None, + position_ids: Optional[Tensor]=None, + attention_mask: Optional[Tensor]=None, + past_key_values: Optional[Tuple[Tuple[Tensor]]]=None, + encoder_embeds=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache: Optional[bool]=None, + output_hidden_states: Optional[bool]=None, + output_attentions: Optional[bool]=None, + return_dict: Optional[bool]=None, + is_decoder=False, + mode='multimodal', ): + r""" + The BertModel forward method, overrides the `__call__()` special method. + + Args: + input_ids (Tensor): + Indices of input sequence tokens in the vocabulary. They are + numerical representations of tokens that build the input sequence. + Its data type should be `int64` and it has a shape of [batch_size, sequence_length]. + token_type_ids (Tensor, optional): + Segment token indices to indicate different portions of the inputs. + Selected in the range ``[0, type_vocab_size - 1]``. + If `type_vocab_size` is 2, which means the inputs have two portions. + Indices can either be 0 or 1: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + Its data type should be `int64` and it has a shape of [batch_size, sequence_length]. + Defaults to `None`, which means we don't add segment embeddings. + position_ids(Tensor, optional): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + max_position_embeddings - 1]``. + Shape as `(batch_size, num_tokens)` and dtype as int64. Defaults to `None`. + attention_mask (Tensor, optional): + Mask used in multi-head attention to avoid performing attention on to some unwanted positions, + usually the paddings or the subsequent positions. + Its data type can be int, float and bool. + When the data type is bool, the `masked` tokens have `False` values and the others have `True` values. + When the data type is int, the `masked` tokens have `0` values and the others have `1` values. + When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values. + It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`. + Defaults to `None`, which means nothing needed to be prevented attention to. + past_key_values (tuple(tuple(Tensor)), optional): + The length of tuple equals to the number of layers, and each inner + tuple haves 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`) + which contains precomputed key and value hidden states of the attention blocks. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, optional): + If set to `True`, `past_key_values` key value states are returned. + Defaults to `None`. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `None`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `None`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.ModelOutput` object. If `False`, the output + will be a tuple of tensors. Defaults to `None`. + + Returns: + An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions` if + `return_dict=True`. Otherwise it returns a tuple of tensors corresponding + to ordered and not None (depending on the input arguments) fields of + :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`. + + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers import BertModel, BertTokenizer + + tokenizer = BertTokenizer.from_pretrained('bert-wwm-chinese') + model = BertModel.from_pretrained('bert-wwm-chinese') + + inputs = tokenizer("欢迎使用百度飞桨!") + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + output = model(**inputs) + """ + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.shape + batch_size, seq_length = input_shape + elif inputs_embeds is not None: + input_shape = inputs_embeds.shape[:-1] + batch_size, seq_length = input_shape + elif encoder_embeds is not None: + input_shape = encoder_embeds.shape[:-1] + batch_size, seq_length = input_shape + else: + raise ValueError( + "You have to specify either input_ids or inputs_embeds or encoder_embeds" + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + use_cache = use_cache if use_cache is not None else self.config.use_cache + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0].shape + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.shape + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = paddle.ones(encoder_hidden_shape) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + if attention_mask is None: + attention_mask = paddle.ones( + (batch_size, seq_length + past_key_values_length)) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, ) + else: + embedding_output = encoder_embeds + + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, is_decoder) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, ) + + + + + +class BertLMPredictionHead(Layer): + """ + Bert Model with a `language modeling` head on top for CLM fine-tuning. + """ + + def __init__(self, config: BertConfig): + super(BertLMPredictionHead, self).__init__() + + self.transform = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = getattr(nn.functional, config.hidden_act) + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.decoder = TransposedLinear(config.hidden_size, config.vocab_size) + # link bias to load pretrained weights + self.decoder_bias = self.decoder.bias + + def forward(self, hidden_states, masked_positions=None): + if masked_positions is not None: + hidden_states = paddle.reshape(hidden_states, + [-1, hidden_states.shape[-1]]) + hidden_states = paddle.tensor.gather(hidden_states, + masked_positions) + # gather masked tokens might be more quick + hidden_states = self.transform(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + + + + +class BertOnlyMLMHead(nn.Layer): + def __init__(self, config: BertConfig): + super().__init__() + self.predictions = BertLMPredictionHead(config=config) + + def forward(self, sequence_output, masked_positions=None): + prediction_scores = self.predictions(sequence_output, masked_positions) + return prediction_scores + + +class BertForMaskedLM(BertPretrainedModel): + """ + Bert Model with a `masked language modeling` head on top. + + Args: + config (:class:`BertConfig`): + An instance of BertConfig used to construct BertForMaskedLM. + + """ + + def __init__(self, config: BertConfig): + super(BertForMaskedLM, self).__init__(config) + self.bert = BertModel(config) + + self.cls = BertOnlyMLMHead(config=config) + self.tie_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def forward( + self, + input_ids: Tensor, + token_type_ids: Optional[Tensor]=None, + position_ids: Optional[Tensor]=None, + attention_mask: Optional[Tensor]=None, + masked_positions: Optional[Tensor]=None, + labels: Optional[Tensor]=None, + output_hidden_states: Optional[bool]=None, + output_attentions: Optional[bool]=None, + return_dict: Optional[bool]=None, ): + r""" + + Args: + input_ids (Tensor): + See :class:`BertModel`. + token_type_ids (Tensor, optional): + See :class:`BertModel`. + position_ids (Tensor, optional): + See :class:`BertModel`. + attention_mask (Tensor, optional): + See :class:`BertModel`. + labels (Tensor of shape `(batch_size, sequence_length)`, optional): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., vocab_size]` + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `None`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `None`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `None`. + + Returns: + An instance of :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput`. + + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers import BertForMaskedLM, BertTokenizer + + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = BertForMaskedLM.from_pretrained('bert-base-uncased') + + inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!") + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + + logits = model(**inputs) + print(logits.shape) + # [1, 13, 30522] + + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.bert( + input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output, + masked_positions=masked_positions) + + masked_lm_loss = None + if labels is not None: + loss_fct = paddle.nn.CrossEntropyLoss( + ) # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.reshape((-1, prediction_scores.shape[-1])), + labels.reshape((-1, ))) + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return (((masked_lm_loss, ) + output) if masked_lm_loss is not None + else (output[0] if len(output) == 1 else output)) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) + + +class BertLMHeadModel(BertPretrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [ + r"position_ids", r"predictions.decoder.bias" + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', ): + r""" + encoder_hidden_states (:obj:`paddle.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`paddle.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`paddle.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(paddle.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import paddle + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :] + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :] + labels = labels[:, 1:] + loss_fct = nn.CrossEntropyLoss(reduction=reduction) + lm_loss = loss_fct( + shifted_prediction_scores.reshape( + [-1, self.config.vocab_size]), labels.reshape([-1])) + if reduction == 'none': + lm_loss = lm_loss.reshape(prediction_scores.shape[0], + -1).sum(1) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": + model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": + model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past diff --git a/ppcls/arch/ram/ram.py b/ppcls/arch/ram/ram.py new file mode 100644 index 0000000000..d24a98b91d --- /dev/null +++ b/ppcls/arch/ram/ram.py @@ -0,0 +1,511 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Code was based on https://github.com/xinyu1205/recognize-anything/tree/main +# reference: https://arxiv.org/abs/2306.03514 + +import yaml +import numpy as np + +from paddle import nn +import paddle +from paddlenlp.transformers import BertTokenizer +from paddle.nn import functional as F +from paddle.nn.initializer import Constant +from paddlenlp.transformers.bert.configuration import BertConfig + +from ..backbone.model_zoo.vision_transformer import VisionTransformer +from .bert import BertModel, BertLMHeadModel +from ..clip.clip import CLIP_DICT, tokenize +from ..clip.tokenizer import Tokenizer +from ..backbone.legendary_models.swin_transformer import SwinTransformer + + +class RamVis(VisionTransformer): + def forward_features(self, x): + return x + + +class RamSwin(SwinTransformer): + def forward_features(self, x): + x, output_dimensions = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x, output_dimensions = layer(x, output_dimensions) + + x = self.norm(x) # B L C + x_cls = self.avgpool(x.transpose([0, 2, 1])) # B C 1 + return paddle.concat([x_cls.transpose([0, 2, 1]), x], axis=1) + + def forward(self, x): + x = self.forward_features(x) + return x + + +def RamSwin_large_patch4_window12_384(): + return RamSwin( + img_size=384, + patch_size=4, + in_chans=3, + embed_dim=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=12, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0.0, + drop_path_rate=0.1, + ape=False, + patch_norm=True, + use_checkpoint=False) + + +def RamSwin_large_patch4_window7_224(): + return RamSwin( + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0.0, + drop_path_rate=0.1, + ape=False, + patch_norm=True, + use_checkpoint=False) + + +def RamSwin_base_patch4_window12_384(): + return RamSwin( + img_size=384, + patch_size=4, + in_chans=3, + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=12, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0.0, + drop_path_rate=0.1, + ape=False, + patch_norm=True, + use_checkpoint=False) + + +def RamSwin_base_patch4_window7_224(): + return RamSwin( + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0.0, + drop_path_rate=0.1, + ape=False, + patch_norm=True, + use_checkpoint=False) + + +CONFIG_PATH = 'ppcls' + + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +def build_text_embed(model_clip, tokenizer, caption): + with paddle.no_grad(): + texts = tokenize(caption, tokenizer) + text_embeddings = model_clip.encode_text(texts) + text_embeddings /= text_embeddings.norm(axis=-1, keepdim=True) + return text_embeddings + + +class AsymmetricLoss(nn.Layer): + def __init__(self, + gamma_neg=4, + gamma_pos=1, + clip=0.05, + eps=1e-8, + disable_torch_grad_focal_loss=True): + super(AsymmetricLoss, self).__init__() + + self.gamma_neg = gamma_neg + self.gamma_pos = gamma_pos + self.clip = clip + self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss + self.eps = eps + + def forward(self, x, y): + """" + Parameters + ---------- + x: input logits + y: targets (multi-label binarized vector) + """ + + # Calculating Probabilities + x_sigmoid = nn.functional.sigmoid(x) + xs_pos = x_sigmoid + xs_neg = 1 - x_sigmoid + + # Asymmetric Clipping + if self.clip is not None and self.clip > 0: + xs_neg = (xs_neg + self.clip).clip(max=1) + + # Basic CE calculation + los_pos = y * paddle.log(xs_pos.clip(min=self.eps)) + los_neg = (1 - y) * paddle.log(xs_neg.clip(min=self.eps)) + loss = los_pos + los_neg + + # Asymmetric Focusing + if self.gamma_neg > 0 or self.gamma_pos > 0: + if self.disable_torch_grad_focal_loss: + paddle.set_grad_enabled(False) + pt0 = xs_pos * y + pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p + pt = pt0 + pt1 + one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) + one_sided_w = paddle.pow(1 - pt, one_sided_gamma.astype(pt.dtype)) + if self.disable_torch_grad_focal_loss: + paddle.set_grad_enabled(True) + loss *= one_sided_w + + return -loss.sum() + + +def init_tokenizer(tokenizer_name="bert-base-uncased"): + tokenizer = BertTokenizer.from_pretrained(tokenizer_name) + tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']}) + tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] + return tokenizer + + +def read_yaml(rpath): + with open(rpath, 'r') as f: + return yaml.safe_load(f) + + +def create_vit(vit, + image_size, + use_grad_checkpointing=False, + ckpt_layer=0, + drop_path_rate=0): + + assert vit in ['base', 'large'], "vit parameter must be base or large" + if vit == 'base': + vision_width = 768 + visual_encoder = RamVis( + img_size=image_size, + patch_size=16, + embed_dim=vision_width, + depth=12, + class_num=0, + num_heads=12, + use_grad_checkpointing=use_grad_checkpointing, + ckpt_layer=ckpt_layer, + drop_path_rate=0 or drop_path_rate) + elif vit == 'large': + vision_width = 1024 + visual_encoder = RamVis( + img_size=image_size, + patch_size=16, + class_num=0, + embed_dim=vision_width, + depth=24, + num_heads=16, + use_grad_checkpointing=use_grad_checkpointing, + ckpt_layer=ckpt_layer, + drop_path_rate=0.1 or drop_path_rate) + return visual_encoder, vision_width + + +class RAM(nn.Layer): + def __init__(self, + med_config='', + image_size=384, + vit='base', + vit_grad_ckpt=False, + vit_ckpt_layer=0, + prompt='', + clip_pretraind=None, + threshold=0.68, + delete_tag_index=[], + tag_list='', + tag_list_chinese='', + clip_version='', + q2l_config='', + ram_class_threshold_path='', + pretrained='', + stage='eval'): + + super().__init__() + + # create image encoder + self.stage = stage + if self.stage == 'train': + assert clip_pretraind + self.clip_tokenizer = Tokenizer() + assert clip_version in CLIP_DICT.keys( + ), 'please check the clip structure' + self.CLIP = CLIP_DICT[clip_version] + params = paddle.load(clip_pretraind) + self.CLIP.set_state_dict(params) + self.CLIP.eval() + + if vit == 'swin_b': + vision_width = 1024 + if image_size == 224: + self.visual_encoder = RamSwin_base_patch4_window7_224() + elif image_size == 384: + self.visual_encoder = RamSwin_base_patch4_window12_384() + + elif vit == 'swin_l': + vision_width = 1536 + if image_size == 224: + self.visual_encoder = RamSwin_large_patch4_window7_224() + elif image_size == 384: + self.visual_encoder = RamSwin_large_patch4_window12_384() + + else: + self.visual_encoder, vision_width = create_vit( + vit, image_size, vit_grad_ckpt, vit_ckpt_layer) + + # create tokenzier + self.tokenizer = init_tokenizer() + + # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder + # create image-tag interaction encoder + encoder_config = BertConfig.from_dict(read_yaml(med_config)) + encoder_config.encoder_width = 512 + self.tag_encoder = BertModel( + config=encoder_config, add_pooling_layer=False) + + # create image-tag-text decoder + decoder_config = BertConfig.from_dict(read_yaml(med_config)) + self.text_decoder = BertLMHeadModel(config=decoder_config) + + self.delete_tag_index = delete_tag_index + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 + + # load tag list + self.tag_list = self.load_tag_list(tag_list) + self.tag_list_chinese = self.load_tag_list(tag_list_chinese) + + # create image-tag recognition decoder + self.threshold = threshold + self.num_class = len(self.tag_list) + q2l_config = BertConfig.from_dict(read_yaml(q2l_config)) + q2l_config.encoder_width = 512 + self.tagging_head = BertModel( + config=q2l_config, add_pooling_layer=False) + self.tagging_head.resize_token_embeddings(len(self.tokenizer)) + self.label_embed = self.create_parameter( + shape=(self.num_class, q2l_config.encoder_width), + default_initializer=Constant()) + + if q2l_config.hidden_size != 512: + self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size) + else: + self.wordvec_proj = nn.Identity() + + self.fc = nn.Linear(q2l_config.hidden_size, 1) + + self.del_selfattention() + + self.tagging_loss_function = AsymmetricLoss( + gamma_neg=7, gamma_pos=0, clip=0.05) + + self.image_proj = nn.Linear(vision_width, 512) + + # adjust thresholds for some tags + self.class_threshold = paddle.ones([self.num_class]) * self.threshold + ram_class_threshold_path = ram_class_threshold_path + with open(ram_class_threshold_path, 'r', encoding='utf-8') as f: + ram_class_threshold = [float(s.strip()) for s in f] + for key, value in enumerate(ram_class_threshold): + self.class_threshold[key] = value + + def load_tag_list(self, tag_list_file): + with open(tag_list_file, 'r', encoding='utf-8') as f: + tag_list = f.read().splitlines() + tag_list = np.array(tag_list) + return tag_list + + # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label + def del_selfattention(self): + del self.tagging_head.embeddings + for layer in self.tagging_head.encoder.layer: + del layer.attention + + def forward(self, + image_ram, + text=None, + image_tag=None, + tag_input_tokenzier=None, + image_clip=None): + """ + image-》 image_ram + image224 -> image_clip + call function as forward + + Args: + image: type: paddle.Tensor shape: batch_size * 3 * 384 * 384 + caption: type: paddle.Tensor len: batch_size * embedding_size + tag: type: paddle.Tensor shape: batch * class_num (e.g. 3429) value: positive sample is 1.0, negative sample is 0.0 + + Returns: + loss: type: paddle.Tensor + """ + assert self.stage == 'train' + label_embed = nn.functional.relu(self.wordvec_proj(self.label_embed)) + clip_feature = self.CLIP.encode_image(image_clip) + + image_embeds = self.image_proj(self.visual_encoder(image_ram)) + image_atts = paddle.ones( + paddle.shape(image_embeds)[:-1], dtype=paddle.int32) + + ##================= Distillation from CLIP ================## + image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + loss_dis = 0. + if isinstance(clip_feature, paddle.Tensor): + loss_dis = F.l1_loss(image_cls_embeds, clip_feature) + +##================= Image Tagging ================## + bs = paddle.shape(image_embeds)[0] + #label_embed = paddle.repeat_interleave(label_embed.unsqueeze(0),[bs, 1, 1]) + label_embed = label_embed.unsqueeze(0).tile([bs, 1, 1]).squeeze(1) + + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', ) + + logits = self.fc(tagging_embed[0]).squeeze(-1) + loss_tag = 0. + if isinstance(image_tag, paddle.Tensor): + loss_tag = self.tagging_loss_function(logits, image_tag) + + ##================= Image-Tag-Text Generation ================## + encoder_input_ids = tag_input_tokenzier.get("input_ids") + encoder_input_ids[:, 0] = self.tokenizer.enc_token_id + + # put input tag into image-tag interaction encoder to interact with image embeddings + output_tagembedding = self.tag_encoder( + encoder_input_ids, + attention_mask=tag_input_tokenzier.get("attention_mask"), + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, ) + + decoder_input_ids = text.get("input_ids") + decoder_input_ids[:, 0] = self.tokenizer.enc_token_id + + decoder_targets = masked_fill( + decoder_input_ids, + decoder_input_ids == self.tokenizer.pad_token_id, -100) + decoder_targets[:, :self.prompt_length] = -100 + + decoder_output = self.text_decoder( + decoder_input_ids, + attention_mask=text.get("attention_mask"), + encoder_hidden_states=output_tagembedding.last_hidden_state, + encoder_attention_mask=None, + labels=decoder_targets, + return_dict=True, ) + + loss_t2t = decoder_output.loss + + return loss_t2t, loss_tag, loss_dis + + # to support paddle framework + def inference( + self, + image, + threshold=0.4, + tag_input=None, ): + + label_embed = F.relu(self.wordvec_proj(self.label_embed)) + + image_embeds = self.image_proj(self.visual_encoder(image)) + image_atts = paddle.ones(image_embeds.shape[:-1], dtype=paddle.int32) + + # recognized image tags using image-tag recogntiion decoder + image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + bs = image_spatial_embeds.shape[0] + label_embed = label_embed.unsqueeze(0).tile([bs, 1, 1]).squeeze(1) + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', ) + + logits = self.fc(tagging_embed[0]).squeeze(-1) + + return logits + + def generate_tag_openset( + self, + image, + threshold=0.68, + tag_input=None, ): + + label_embed = F.relu(self.wordvec_proj(self.label_embed)) + + image_embeds = self.image_proj(self.visual_encoder(image)) + image_atts = paddle.ones( + [image_embeds.size()[:-1]], dtype=paddle.int32) + + # recognized image tags using image-tag recogntiion decoder + image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + bs = image_spatial_embeds.shape[0] + label_embed = label_embed.unsqueeze(0).repeat([bs, 1, 1]) + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', ) + + logits = self.fc(tagging_embed[0]).squeeze(-1) + + return logits + + +# load RAM pretrained model parameters +def ram(pretrained='', **kwargs): + model = RAM(pretrained='', **kwargs) + return model diff --git a/ppcls/arch/ram/ram_plus.py b/ppcls/arch/ram/ram_plus.py new file mode 100644 index 0000000000..695e529720 --- /dev/null +++ b/ppcls/arch/ram/ram_plus.py @@ -0,0 +1,249 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Code was based on https://github.com/xinyu1205/recognize-anything/tree/main +# reference: https://arxiv.org/abs/2310.15200 + +import paddle +import numpy as np +from paddle import nn +from paddle.nn import functional as F +from paddle.nn.initializer import Constant + +from .bert import BertModel, BertLMHeadModel +from ..clip.clip import tokenize +from .ram import RAM, AsymmetricLoss + + +def build_text_embed(model_clip, texts): + with paddle.no_grad(): + text_embeddings = model_clip.encode_text(texts) + text_embeddings /= text_embeddings.norm(axis=-1, keepdim=True) + return text_embeddings + + +class RAM_plus(RAM): + def __init__(self, + med_config='', + image_size=None, + vit='', + vit_grad_ckpt=False, + vit_ckpt_layer=0, + prompt='', + threshold=0.68, + delete_tag_index=[], + clip_pretraind='', + tag_list='', + tag_list_chinese='', + clip_version='', + q2l_config='', + ram_class_threshold_path='', + pretrained='', + stage='eval'): + + super().__init__( + med_config=med_config, + image_size=image_size, + vit=vit, + vit_grad_ckpt=vit_grad_ckpt, + vit_ckpt_layer=vit_ckpt_layer, + prompt=prompt, + threshold=threshold, + delete_tag_index=delete_tag_index, + clip_pretraind=clip_pretraind, + tag_list=tag_list, + clip_version=clip_version, + tag_list_chinese=tag_list_chinese, + q2l_config=q2l_config, + ram_class_threshold_path=ram_class_threshold_path, + stage=stage) + + self.label_embed = self.create_parameter( + shape=(self.num_class * 51, 512), default_initializer=Constant()) + self.reweight_scale = self.create_parameter( + shape=(1, ), default_initializer=Constant(1. * np.log(1 / 0.07))) + self.text_alignment_loss_function = AsymmetricLoss( + gamma_neg=4, gamma_pos=0, clip=0.05) + + def forward(self, image_ram, caption, image_tag, parse_tag, + imageclip=None): + """ + call function as forward + + Args: + image_ram: type: paddle.Tensor shape: batch_size * 3 * 384 * 384 + caption: type: list[string] len: batch_size + image_tag: type: paddle.Tensor shape: batch * class_num (e.g. 3429) value: positive sample is 1.0, negative sample is 0.0 + parse_tag: text for image_tag + imageclip = image for clip encoder + + Returns: + loss: type: paddle.Tensor + """ + assert self.stage == 'train' + clip_feature = self.CLIP.encode_image(imageclip) + batch_text_embed = build_text_embed(self.CLIP, caption) + image_embeds = self.image_proj(self.visual_encoder(image_ram)) + image_atts = paddle.ones(image_embeds.shape[:-1], dtype=paddle.int32) + ##================= Distillation from CLIP ================## + image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + loss_dis = F.l1_loss(image_cls_embeds, clip_feature) + + ##================= Image Tagging ================## + bs = paddle.shape(image_embeds)[0] + des_per_class = int(self.label_embed.shape[0] / self.num_class) + image_cls_embeds = image_cls_embeds / image_cls_embeds.norm( + axis=-1, keepdim=True) + reweight_scale = self.reweight_scale.exp() + logits_per_image = (reweight_scale * image_cls_embeds + @self.label_embed.t()) + logits_per_image = logits_per_image.reshape([bs, -1, des_per_class]) + weight_normalized = nn.functional.softmax(logits_per_image, axis=2) + label_embed_reweight = paddle.empty([bs, self.num_class, 512]) + for i in range(bs): + reshaped_value = self.label_embed.reshape([-1, des_per_class, 512]) + product = weight_normalized[i].unsqueeze(-1) * reshaped_value + label_embed_reweight[i] = product.sum(axis=1) + + label_embed = nn.functional.relu( + self.wordvec_proj(label_embed_reweight)) + + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', ) + + logits = self.fc(tagging_embed[0]).squeeze(-1) + loss_tag = self.tagging_loss_function(logits, image_tag) + + ##================= Image-text Alignment ================## + + batch_text_embed = F.relu( + self.wordvec_proj( + batch_text_embed.astype(self.label_embed.dtype))) + batch_text_embed = batch_text_embed.unsqueeze(0).tile([bs, 1, 1]) + alignment_embedding = self.tagging_head( + encoder_embeds=batch_text_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', ) + alignment_logits = self.fc(alignment_embedding[0]).squeeze(-1) + + with paddle.no_grad(): + alignment_targets = paddle.zeros(alignment_logits.shape) + alignment_targets.fill_diagonal_(1) + + loss_alignment = self.text_alignment_loss_function(alignment_logits, + alignment_targets) + + return loss_tag, loss_dis, loss_alignment + + # to support paddle framework + def inference(self, image): + + image_embeds = self.image_proj(self.visual_encoder(image)) + image_atts = paddle.ones(image_embeds.shape[:-1], dtype=paddle.int32) + + image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + bs = image_spatial_embeds.shape[0] + + des_per_class = int(self.label_embed.shape[0] / self.num_class) + + image_cls_embeds = image_cls_embeds / image_cls_embeds.norm( + axis=-1, keepdim=True) + reweight_scale = self.reweight_scale.exp() + logits_per_image = (reweight_scale * image_cls_embeds + @self.label_embed.t()) + logits_per_image = logits_per_image.reshape([bs, -1, des_per_class]) + + weight_normalized = F.softmax(logits_per_image, axis=2) + label_embed_reweight = paddle.empty([bs, self.num_class, 512]) + + for i in range(bs): + # boardingcast + reshaped_value = self.label_embed.reshape([-1, des_per_class, 512]) + product = weight_normalized[i].unsqueeze(-1) * reshaped_value + label_embed_reweight[i] = product.sum(axis=1) + + label_embed = F.relu(self.wordvec_proj(label_embed_reweight)) + + # recognized image tags using alignment decoder + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', ) + + logits = self.fc(tagging_embed[0]).squeeze(-1) + + return logits + + def generate_tag_openset( + self, + image, + threshold=0.68, + tag_input=None, ): + + image_embeds = self.image_proj(self.visual_encoder(image)) + image_atts = paddle.ones(image_embeds.shape[:-1], dtype=paddle.int32) + + image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + bs = image_spatial_embeds.shape[0] + + des_per_class = int(self.label_embed.shape[0] / self.num_class) + + image_cls_embeds = image_cls_embeds / image_cls_embeds.norm( + axis=-1, keepdim=True) + reweight_scale = self.reweight_scale.exp() + logits_per_image = (reweight_scale * image_cls_embeds + @self.label_embed.t()) + logits_per_image = logits_per_image.reshape([bs, -1, des_per_class]) + + weight_normalized = F.softmax(logits_per_image, axis=2) + label_embed_reweight = paddle.empty([bs, self.num_class, 512]) + + for i in range(bs): + # boardingcast + reshaped_value = self.label_embed.reshape([-1, des_per_class, 512]) + product = weight_normalized[i].unsqueeze(-1) * reshaped_value + label_embed_reweight[i] = product.sum(axis=1) + + label_embed = F.relu(self.wordvec_proj(label_embed_reweight)) + + # recognized image tags using alignment decoder + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', ) + + logits = self.fc(tagging_embed[0]).squeeze(-1) + + return logits + + +# load RAM pretrained model parameters +def ram_plus(pretrained='', **kwargs): + model = RAM_plus(pretrained='', **kwargs) + return model diff --git a/ppcls/configs/RAM/CLIP.yaml b/ppcls/configs/RAM/CLIP.yaml new file mode 100644 index 0000000000..6cadf224ae --- /dev/null +++ b/ppcls/configs/RAM/CLIP.yaml @@ -0,0 +1,124 @@ +# global configs +Global: + checkpoints: null + pretrained_model: "ViT-B-32.pdparams" # pretrain model for ram and ram plus, default random initilize + output_dir: ./output/ + device: cpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 120 + print_batch_step: 1 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference_text + + + +# model architecture +Arch: + name: CLIP_vit_base_patch32_224_with_TextEncoder + clip: "text" + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: RAMPretrainDataset + ann_file: [./visual-genome/vg_ram.json] + tag_list: 'ppcls/utils/ram/ram_tag_list.txt' + model_name: "RAM" + transform_ops_ram: + - ResizeImage: + size: 384 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + transform_ops_clip: + - ResizeImage: + size: 224 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + order: '' + - ToCHWImage: + + sampler: + name: DistributedBatchSampler + batch_size: 52 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: RAMPretrainDataset + ann_file: [./visual-genome/vg_ram.json] + tag_list: 'ppcls/utils/ram/ram_tag_list.txt' + model_name: "RAM" + transform_ops_ram: + - ResizeImage: + size: 384 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + transform_ops_clip: + - ResizeImage: + size: 224 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + order: '' + - ToCHWImage: + + sampler: + name: DistributedBatchSampler + batch_size: 52 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 1 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 224 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: RamOutPut + language: "cn" + tag_list: "ppcls/utils/ram/ram_tag_list.txt" + tag_list_chinese: "ppcls/utils/ram/ram_tag_list_chinese.txt" + ram_class_threshold_path: "ppcls/utils/ram/ram_tag_list_threshold.txt" diff --git a/ppcls/configs/ram/RAM.yaml b/ppcls/configs/ram/RAM.yaml new file mode 100644 index 0000000000..a33581140f --- /dev/null +++ b/ppcls/configs/ram/RAM.yaml @@ -0,0 +1,172 @@ +# global configs +Global: + checkpoints: null + pretrained_model: "ram.pdparams" # pretrain model for ram and ram plus, default random initilize + output_dir: ./output/ + device: cpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 120 + print_batch_step: 1 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 384, 384] + save_inference_dir: ./inference + + +# mixed precision +AMP: + use_amp: True + use_fp16_test: False + scale_loss: 128.0 + use_dynamic_loss_scaling: True + use_promote: False + # O1: mixed fp16, O2: pure fp16 + level: O1 + + +# model architecture +Arch: + name: ram + vit: swin_l + vit_grad_ckpt: False + vit_ckpt_layer: 0 + image_size: 384 + prompt: 'a picture of ' + med_config: 'ppcls/configs/ram/RAM_bert.yaml' + delete_tag_index: [] + tag_list: 'ppcls/utils/ram/ram_tag_list.txt' + tag_list_chinese: 'ppcls/utils/ram/ram_tag_list_chinese.txt' + clip_pretraind: ./ViT-B-32.pdparams #for CLIP a necessary part for training ram + clip_version: 'vit-b-32-224' + q2l_config: 'ppcls/configs/ram/RAM_q2l.yaml' + ram_class_threshold_path: 'ppcls/utils/ram/ram_tag_list_threshold.txt' + stage: train + threshold: 0.68 + +# loss function config for traing/eval process +Loss: + Train: + - RAMLoss: + weight: 1.0 + Eval: + - RAMLoss: + weight: 1.0 + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.05 + layerwise_decay: 0.6 + filter_bias_and_bn: True + lr: + name: Cosine + learning_rate: 0.0004 + eta_min: 1e-6 + warmup_epoch: 10 + warmup_start_lr: 5e-7 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: RAMPretrainDataset + ann_file: [./visual-genome/vg_ram.json] + tag_list: 'ppcls/utils/ram/ram_tag_list.txt' + model_name: "RAM" + transform_ops_ram: + - ResizeImage: + size: 384 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + transform_ops_clip: + - ResizeImage: + size: 224 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + order: '' + - ToCHWImage: + + sampler: + name: DistributedBatchSampler + batch_size: 52 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: RAMPretrainDataset + ann_file: [./visual-genome/vg_ram.json] + tag_list: 'ppcls/utils/ram/ram_tag_list.txt' + model_name: "RAM" + transform_ops_ram: + - ResizeImage: + size: 384 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + transform_ops_clip: + - ResizeImage: + size: 224 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + order: '' + - ToCHWImage: + + sampler: + name: DistributedBatchSampler + batch_size: 52 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 1 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 384 + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: RamOutPut + language: "cn" + tag_list: "ppcls/utils/ram/ram_tag_list.txt" + tag_list_chinese: "ppcls/utils/ram/ram_tag_list_chinese.txt" + ram_class_threshold_path: "ppcls/utils/ram/ram_tag_list_threshold.txt" diff --git a/ppcls/configs/ram/RAM_bert.yaml b/ppcls/configs/ram/RAM_bert.yaml new file mode 100644 index 0000000000..18287a458e --- /dev/null +++ b/ppcls/configs/ram/RAM_bert.yaml @@ -0,0 +1,20 @@ +architectures: [ + 'BertModel' + ] #default uncased + +attention_probs_dropout_prob: 0.1 # attention dropout rata +hidden_act: gelu # acitvation +hidden_dropout_prob: 0.1 +hidden_size: 768 # embedding size +initializer_range: 0.02 +intermediate_size: 3072 +layer_norm_eps: 1.0e-12 +max_position_embeddings: 512 +model_type: bert +num_attention_heads: 12 +num_hidden_layers: 12 +pad_token_id: 0 +type_vocab_size: 2 +vocab_size: 30524 # text embedding size +encoder_width: 768 +add_cross_attention: true \ No newline at end of file diff --git a/ppcls/configs/ram/RAM_plus.yaml b/ppcls/configs/ram/RAM_plus.yaml new file mode 100644 index 0000000000..12fcbd1ff4 --- /dev/null +++ b/ppcls/configs/ram/RAM_plus.yaml @@ -0,0 +1,176 @@ +# global configs +Global: + checkpoints: null + pretrained_model: "ram_plus.pdparams" # pretrain model for ram and ram plus, default random initilize + output_dir: ./output/ + device: cpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 120 + print_batch_step: 1 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 384, 384] + save_inference_dir: ./inference + + +# mixed precision +AMP: + use_amp: True + use_fp16_test: False + scale_loss: 128.0 + use_dynamic_loss_scaling: True + use_promote: False + # O1: mixed fp16, O2: pure fp16 + level: O1 + + +# model architecture +Arch: + name: ram_plus + vit: swin_l + vit_grad_ckpt: False + vit_ckpt_layer: 0 + image_size: 384 + prompt: 'a picture of ' + med_config: 'ppcls/configs/ram/RAM_bert.yaml' + delete_tag_index: [] + tag_list: 'ppcls/utils/ram/ram_tag_list.txt' + tag_list_chinese: 'ppcls/utils/ram/ram_tag_list_chinese.txt' + clip_pretraind: ./ViT-B-32.pdparams #for CLIP a necessary part for training ram + clip_version: 'vit-b-32-224' + q2l_config: 'ppcls/configs/ram/RAM_q2l.yaml' + ram_class_threshold_path: 'ppcls/utils/ram/ram_tag_list_threshold.txt' + stage: train + threshold: 0.68 + + +# loss function config for traing/eval process +Loss: + Train: + - RAMLoss: + weight: 1.0 + Eval: + - RAMLoss: + weight: 1.0 + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.05 + layerwise_decay: 0.6 + filter_bias_and_bn: True + lr: + name: Cosine + learning_rate: 0.0004 + eta_min: 1e-6 + warmup_epoch: 10 + warmup_start_lr: 5e-7 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: RAMPretrainDataset + ann_file: [./visual-genome/vg_ram.json] + tag_list: 'ppcls/utils/ram/ram_tag_list.txt' + model_name: "RAM_plus" + transform_ops_ram: + - ResizeImage: + size: 384 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + transform_ops_clip: + - ResizeImage: + size: 224 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + order: '' + - ToCHWImage: + + sampler: + name: DistributedBatchSampler + batch_size: 52 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: RAMPretrainDataset + ann_file: [./visual-genome/vg_ram.json] + tag_list: 'ppcls/utils/ram/ram_tag_list.txt' + model_name: "RAM_plus" + transform_ops_ram: + - ResizeImage: + size: 384 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + transform_ops_clip: + - ResizeImage: + size: 224 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + order: '' + - ToCHWImage: + + sampler: + name: DistributedBatchSampler + batch_size: 52 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 1 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 384 + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: RamOutPut + language: "cn" + tag_list: "ppcls/utils/ram/ram_tag_list.txt" + tag_list_chinese: "ppcls/utils/ram/ram_tag_list_chinese.txt" + ram_class_threshold_path: "ppcls/utils/ram/ram_tag_list_threshold.txt" + diff --git a/ppcls/configs/ram/RAM_q2l.yaml b/ppcls/configs/ram/RAM_q2l.yaml new file mode 100644 index 0000000000..d1e4598230 --- /dev/null +++ b/ppcls/configs/ram/RAM_q2l.yaml @@ -0,0 +1,21 @@ +architectures: [ + 'BertModel' + ] #default uncased + +attention_probs_dropout_prob: 0.1 # attention dropout rata +hidden_act: gelu # acitvation +hidden_dropout_prob: 0.1 +hidden_size: 768 # embedding size +initializer_range: 0.02 +intermediate_size: 3072 +layer_norm_eps: 1.0e-12 +max_position_embeddings: 512 +model_type: bert +num_attention_heads: 4 +num_hidden_layers: 2 +pad_token_id: 0 +type_vocab_size: 2 +vocab_size: 30524 # text embedding size +encoder_width: 768 +add_cross_attention: true # whether using cross-attention default False +add_tag_cross_attention : false \ No newline at end of file diff --git a/ppcls/configs/ram/README.md b/ppcls/configs/ram/README.md new file mode 100644 index 0000000000..d3900086e6 --- /dev/null +++ b/ppcls/configs/ram/README.md @@ -0,0 +1,390 @@ +# RAM, RAM++ 图文标签模型 + +## 目录 + +* [1. 模型介绍](#1) +* [2. 数据和模型准备](#2) +* [3. 模型训练](#3) +* [4. 模型评估](#4) +* [5. 模型预测](#5) +* [6. 基于预测引擎预测](#6) + * [6.1 导出 inference model](#6.1) + * [6.2 基于 Python 预测引擎推理](#6.2) +* [7. 引用](#7) + + +## 1. 模型介绍 + +RAM以及RAM++(下文简称RAM类模型)主要用于标注类任务,其中两个模型的主要贡献为提出了集训练-推理-tag一体化的框架。其通过堆叠vision encoder以及text encoder,实现多种下游任务。核心方法包括: + +1. 结合CLIP架构,提出 Image-Tag Recognition Decoder,Image-Text Alignment Encoder,Image-Tag Interaction Encoder,Image-Tag-Text Generation Decoder以及Generation Encoder 5个组件分别实现text-image对齐,text-tag对齐。 +2. RAM++进一步使用大语言模型(large language model,LLM)的语义信息,提升text-image对齐的能力。 + +使用RAM类模型时,作者在多个分类任务上取得了最先进的结果: + +| Model | BackBone | Store Size | Inference Prompt | CLIP | OpenImages-MAP | +|-------|------------|--------|------------------|------|----------------| +| RAM | Swin-large | 5.63GB | LLM Tag Dec | VIT-base-patch16-224 | 82.2 | +| RAM++ | Swin-base | 3.01GB | LLM Tag Dec | VIT-base-patch16-224 | 86.6 | +注:LLM Tag Dec表示基于LLM改写的文本tag。例如给定prompt:"A photo of a cat" 对应LLM tag Dec为:"Cat is a small general with sofa". + +`PaddleClas` paddleclas分别实现了基于不同backbone的RAM类模型: +```yaml +# model architecture +Arch: + name: ram_plus + vit: swin_l + vit_grad_ckpt: False + vit_ckpt_layer: 0 + image_size: 384 + prompt: 'a picture of ' + med_config: 'ppcls/configs/ram/ram_bert.yaml' + delete_tag_index: [] + tag_list: 'ppcls/utils/ram/ram_tag_list.txt' + tag_list_chinese: 'ppcls/utils/ram/ram_tag_list_chinese.txt' + clip_pretraind: ./ViT-B-32.pdparams #for CLIP a necessary part for training ram + clip_version: 'vit-b-32-224' + q2l_config: 'ppcls/configs/ram/ram_q2l.yaml' + ram_class_threshold_path: 'ppcls/utils/RAM/ram_tag_list_threshold.txt' + stage: train + threshold: 0.68 +``` +参数注释: + - name 模型参数,使用RAM模型可以指定为ram,使用RAM++模型可以指定为ram_plus,默认为ram + - vit 视觉主干网络参数,包括 vit:vision transformer,swin_b: swin base模型,swin_l, swin large模型 + - image_size 图片分辨率 + - prompt RAM训练时所使用的文本提示前缀 + - med_config RAM类模型所使用的Bert模型配置文件,默认配置路径:'ppcls/configs/ram/config_bert.yaml' + - delete_tag_index 屏蔽tag所用参数,例如传递[1,3,2]则表示屏蔽index为1,2,3的tag标签 + - tag_list 英文tag标签文件路劲,默认ppcls/utils/RAM/ram_tag_list.txt + - tag_list_chinese 中文tag标签文件路劲,默认ppcls/utils/RAM/ram_tag_list.txt + - clip_version 所使用的CLIP结构,默认 vit-b-32-224 + - clip_pretraind 训练所使用的CLIP预训练参数路径,当需要训练RAM类模型时,不能为None + - q2l_config 基于bert 的text-tag alignment encoder模型配置文件默认 ppcls/configs/ram/config_q2l.yaml + - ram_class_threshold_path tag生成阈值文件默认ppcls/utils/RAM/ram_tag_list_threshold.txt + - stage 指定RAM,RAM++模型是否进行训练,stage = train表示需要训练,训练时clip_pretraind不能为None,stage = eval表示无需训练 + - threshold 输出TAG所需的阈值数值,表示当该tag对应概率大于该值,则认为属于该tag +注意,RAM类模型的推理和训练,需要使用tools/train_multimodal.py, tools/infer_multimodal.py 以及predict_multimodal.py接口,支持多模态输入的动态图训练,推理以及静态图推理。 + + +## 2. 数据和模型准备 + +* 前往官方[repo](https://github.com/xinyu1205/recognize-anything/tree/main)下载对应数据集json文件。同时按照json文件目录格式,准备相应的数据。目录格式为: +```json +{ + { + "image_path": "visual-genome/VG_100K/1.jpg", + "caption": ["trees line the sidewalk"], + "union_label_id": [4480], + "parse_label_id": [[4253, 2461, 2966]] + } +} +``` +参数注释: + - image_path 数据集路径 + - caption 对应图片标注 + - union_label_id 标注对应id + - parse_label_id 将标注仅需名词化后,结果对应的id。例如将"trees line the sidewalk"名词化得到 "trees" "line"以及"sidewalk",其对应的id分别是4253, 2461, 2966 +其中务必保证数据集文件路径符合image_path +* 本文档中,为RAM类模型提供了统一的动态训练以及推理配置文件,结构如下: +```yaml +# global configs +Global: + checkpoints: null + pretrained_model: "ram.pdparams" # pretrain model for ram and ram plus, default random initilize + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 120 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 384, 384] + save_inference_dir: ./inference + + +# mixed precision +AMP: + use_amp: True + use_fp16_test: False + scale_loss: 128.0 + use_dynamic_loss_scaling: True + use_promote: False + # O1: mixed fp16, O2: pure fp16 + level: O1 + + +# model architecture +Arch: + name: ram + vit: swin_l + vit_grad_ckpt: False + vit_ckpt_layer: 0 + image_size: 384 + prompt: 'a picture of ' + med_config: 'ppcls/configs/ram/ram_bert.yaml' + delete_tag_index: [] + tag_list: 'ppcls/utils/ram/ram_tag_list.txt' + tag_list_chinese: 'ppcls/utils/ram/ram_tag_list_chinese.txt' + clip_pretraind: ./ViT-B-32.pdparams #for CLIP a necessary part for training ram + clip_version: 'vit-b-32-224' + q2l_config: 'ppcls/configs/ram/ram_q2l.yaml' + ram_class_threshold_path: 'ppcls/utils/RAM/ram_tag_list_threshold.txt' + stage: train + threshold: 0.68 + +# loss function config for traing/eval process +Loss: + Train: + - RAMLoss: + weight: 1.0 + Eval: + - RAMLoss: + weight: 1.0 + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.05 + layerwise_decay: 0.6 + filter_bias_and_bn: True + lr: + name: Cosine + learning_rate: 0.0004 + eta_min: 1e-6 + warmup_epoch: 10 + warmup_start_lr: 5e-7 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: RAMPretrainDataset + ann_file: [./visual-genome/vg_ram.json] + transform_ops_ram: + - ResizeImage: + size: 384 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + transform_ops_clip: + - ResizeImage: + resize_short: 224 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + order: '' + - ToCHWImage: + + sampler: + name: DistributedBatchSampler + batch_size: 52 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: RAMPretrainDataset + ann_file: [./visual-genome/vg_ram.json] + transform_ops_ram: + - ResizeImage: + size: 384 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + transform_ops_clip: + - ResizeImage: + resize_short: 224 + interpolation: bicubic + backend: pil + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + order: '' + - ToCHWImage: + + sampler: + name: DistributedBatchSampler + batch_size: 52 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 1 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 384 + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: RamOutPut + language: "cn" + tag_list: "ppcls/utils/RAM/ram_tag_list.txt" + tag_list_chinese: "ppcls/utils/RAM/ram_tag_list_chinese.txt" + ram_class_threshold_path: "ppcls/utils/RAM/ram_tag_list_threshold.txt" + + +``` +用户可以根据自身需求,更改相应配置。注意arch参数请参照本文档。 + +## 3. 模型训练 +以RAM为例: +```shell +# 多卡 +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python3 -m paddle.distributed.launch \ + --gpus="0,1,2,3" \ + tools/train_multimodal.py \ + -c ./ppcls/configs/ram/RAM.yaml +# 单卡 +python3 tools/train_multimodal.py \ + -c ./ppcls/configs//ram/RAM.yaml +``` +以RAM++为例: +```shell +# 多卡 +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python3 -m paddle.distributed.launch \ + --gpus="0,1,2,3" \ + tools/train_multimodal.py \ + -c ./ppcls/configs/ram/RAM_plus.yaml +# 单卡 +python3 tools/train_multimodal.py \ + -c ./ppcls/configs//ram/RAM_plus.yaml +``` + + + +## 4. 模型预测 + +```bash +python3 tools/infer_multimodal.py \ + -c ./ppcls/configs/ram/RAM.yaml \ + -o Global.pretrained_model="./output/ram/best_model" +``` + +得到类似下面的输出: +``` +{'class_ids': [[[593], [871], [998], [2071], [3336], [3862], [4389]]], 'scores': [[[0.9708361625671387], [0.9998403787612915], [0.9122695922851562], +[0.8888279795646667], [0.8671568036079407], [0.8900104761123657], [0.811939001083374]]], 'label_names': ['棕色 | 鸡 | 公鸡 | 母鸡 | 红色 | 站/矗立/摊位 | 走 ']} +``` + + +## 5. 基于预测引擎预测 + + +### 5.1 导出 inference model + +```bash +python3 tools/export_model.py \ + -c ./ppcls/configs/ram/RAM.yaml \ + -o Global.pretrained_model="./output/ram/" +``` +inference model 的路径默认在当前路径下 `./inference` +`./inference` 文件夹下应有如下文件结构: + +``` +├── inference +│ ├── inference.pdiparams +│ ├── inference.pdiparams.info +│ └── inference.pdmodel +``` + + + +### 5.2 基于 Python 预测引擎推理 + +切换到depoly目录下,并且使用deploy中的脚本进行推理前需要确认paddleclas为非本地安装, 如不是请进行切换,不然会出现包的导入错误。 + +```shell +# 本地安装 +pip install -e . +# 非本地安装 +python setup.py install + +# 进入deploy目录下 +cd deploy +``` + + + +#### 5.2.1 预测单张图像 + +运行下面的命令,对图像 `docs/images/inference_deployment/whl_demo.jpg` 进行分类。 + +```shell +# linux使用`python3`,windows使用`python (-m)`来执行脚本 +# 使用下面的命令使用 GPU 进行预测 +python3 python/predict_multimodal.py \ + -c deploy/configs/inference_ram.yaml \ + -o Global.inference_model_dir=../inference/ \ + -o Global.infer_imgs=docs/images/inference_deployment/whl_demo.jpg +# 使用下面的命令使用 CPU 进行预测 +#更改 `config_infer.yaml` 配置文件后 +python3 python/predict_multimodal.py \ + -c deploy/configs/inference_ram.yaml \ + -o Global.inference_model_dir=../inference/ \ + -o Global.infer_imgs=docs/images/inference_deployment/whl_demo.jpg +``` + +输出结果如下: + +``` +whl_demo.jpg-class_ids: [[[593], [871], [998], [2071], [3336], [3862], [4389]]], +whl_demo.jpg-scores: [[[0.9708361625671387], [0.9998403787612915], [0.9122695922851562], +[0.8888279795646667], [0.8671568036079407], [0.8900104761123657], [0.811939001083374]]], +whl_demo.jpg-label_names: ['棕色 | 鸡 | 公鸡 | 母鸡 | 红色 | 站/矗立/摊位 | 走 '] +``` + + + +## 6. 引用 +``` +@article{huang2023inject, + title={Inject Semantic Concepts into Image Tagging for Open-Set Recognition}, + author={Huang, Xinyu and Huang, Yi-Jie and Zhang, Youcai and Tian, Weiwei and Feng, Rui and Zhang, Yuejie and Xie, Yanchun and Li, Yaqian and Zhang, Lei}, + journal={arXiv preprint arXiv:2310.15200}, + year={2023} +} + +@article{zhang2023recognize, + title={Recognize Anything: A Strong Image Tagging Model}, + author={Zhang, Youcai and Huang, Xinyu and Ma, Jinyu and Li, Zhaoyang and Luo, Zhaochuan and Xie, Yanchun and Qin, Yuzhuo and Luo, Tong and Li, Yaqian and Liu, Shilong and others}, + journal={arXiv preprint arXiv:2306.03514}, + year={2023} +} +``` \ No newline at end of file diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index baeed665d0..492b40f959 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -37,6 +37,7 @@ from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset from ppcls.data.dataloader.cifar import Cifar10, Cifar100 from ppcls.data.dataloader.metabin_sampler import DomainShuffleBatchSampler, NaiveIdentityBatchSampler +from ppcls.data.dataloader.ram_dataset import RAMPretrainDataset # sampler from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler diff --git a/ppcls/data/dataloader/__init__.py b/ppcls/data/dataloader/__init__.py index 391dcef65b..ff62485025 100644 --- a/ppcls/data/dataloader/__init__.py +++ b/ppcls/data/dataloader/__init__.py @@ -14,3 +14,4 @@ from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset from ppcls.data.dataloader.cifar import Cifar10, Cifar100 from ppcls.data.dataloader.metabin_sampler import DomainShuffleBatchSampler, NaiveIdentityBatchSampler +from ppcls.data.dataloader.ram_dataset import RAMPretrainDataset diff --git a/ppcls/data/dataloader/ram_dataset.py b/ppcls/data/dataloader/ram_dataset.py new file mode 100644 index 0000000000..782949fb97 --- /dev/null +++ b/ppcls/data/dataloader/ram_dataset.py @@ -0,0 +1,163 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import re +import numpy as np + +from paddle.io import Dataset +import paddle +from .common_dataset import create_operators +from ppcls.data.preprocess import transform +from ppcls.arch.clip.tokenizer import Tokenizer +from ppcls.arch.clip.clip import tokenize +from ppcls.arch.ram.ram import init_tokenizer + +from PIL import Image +from PIL import ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True +Image.MAX_IMAGE_PIXELS = None + + +def pre_caption(caption, max_words=50): + caption = re.sub( + r"([.!\"()*#:;~])", + ' ', + caption.lower(), ) + caption = re.sub( + r"\s{2,}", + ' ', + caption, ) + caption = caption.rstrip('\n') + caption = caption.strip(' ') + + #truncate caption + caption_words = caption.split(' ') + if len(caption_words) > max_words: + caption = ' '.join(caption_words[:max_words]) + + return caption + + +class RAMPretrainDataset(Dataset): + def __init__(self, + ann_file, + width=384, + class_num=4585, + root='', + tag_list='', + model_name='', + transform_ops_ram=None, + transform_ops_clip=None): + + self.ann = [] + for f in ann_file: + print('loading ' + f) + ann = json.load(open(f, 'r')) + self.ann += ann + self.width = width + self.name = model_name.lower() + self.transform_clip = create_operators(transform_ops_clip) + self.transform = create_operators(transform_ops_ram) + self.class_num = class_num + self.root = root + self.tag_list = self.load_tag_list(tag_list) + self.bert_tokenizer = init_tokenizer() + self.clip_tokenizer = Tokenizer() + + + def __len__(self): + return len(self.ann) + + def load_tag_list(self, tag_list_file): + with open(tag_list_file, 'r', encoding='utf-8') as f: + tag_list = f.read().splitlines() + tag_list = np.array(tag_list) + return tag_list + + def collect_fn_list(self, data): + image_ram_list = [] + text_list = [] + image_tag_list = [] + image_parse_tag_list = [] + image_clip_list = [] + for item in data: + i1,i2,i3,i4,i5 = item + image_ram_list.append(i1) + text_list.append(i2) + image_tag_list.append(i3) + image_parse_tag_list.append(i4) + image_clip_list.append(i5) + image_rams = paddle.stack(image_ram_list) + if self.name == "ram": + text_list = self.bert_tokenizer( + text_list, + padding='longest', + truncation=True, + max_length=40, + return_attention_mask=True, + return_tensors='pd') + else: + text_list = tokenize(text_list, self.clip_tokenizer) + image_tags = paddle.stack(image_tag_list) + image_parse_tag_list = np.stack(image_parse_tag_list) + tag_input = [] + for b in range(len(image_ram_list)): + index = np.argwhere(image_parse_tag_list[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + tag_input.append(' | '.join(token)) + image_parse_tags = self.bert_tokenizer( + tag_input, + padding='max_length', + truncation=True, + max_length=40, + return_attention_mask=True, + return_tensors='pd') + image_clips = paddle.stack(image_clip_list) + return [image_rams, text_list , image_tags, image_parse_tags, image_clips] + + def __getitem__(self, index): + + ann = self.ann[index] + + image_path_use = os.path.join(self.root, ann['image_path']) + + try: + image = Image.open(image_path_use).convert('RGB') + image_ram = transform(image, self.transform) + image_ram = paddle.to_tensor(image_ram) + image224 = transform(image, self.transform_clip) + image224 = paddle.to_tensor(image224) + except: + image224 = paddle.ones([3, 224, 224]) + image_ram = paddle.ones([3, self.width, self.width]) + + num = ann['union_label_id'] + image_tag = np.zeros([self.class_num]) + image_tag[num] = 1 + image_tag = paddle.to_tensor(image_tag, dtype=paddle.int32) + + caption_index = np.random.randint(0, len(ann['caption'])) + + caption = pre_caption(ann['caption'][caption_index], 30) + + num = ann['parse_label_id'][caption_index] + parse_tag = np.zeros([self.class_num]) + parse_tag[num] = 1 + + return [image_ram, caption, image_tag, parse_tag, image224] + + diff --git a/ppcls/data/postprocess/__init__.py b/ppcls/data/postprocess/__init__.py index 202f5be8ba..a86c3bbfe8 100644 --- a/ppcls/data/postprocess/__init__.py +++ b/ppcls/data/postprocess/__init__.py @@ -20,6 +20,7 @@ from .threshoutput import ThreshOutput, MultiLabelThreshOutput from .attr_rec import VehicleAttribute, PersonAttribute, TableAttribute from .scoreoutput import ScoreOutput +from .ramoutput import RamOutPut def build_postprocess(config): diff --git a/ppcls/data/postprocess/ramoutput.py b/ppcls/data/postprocess/ramoutput.py new file mode 100644 index 0000000000..9297d39fa5 --- /dev/null +++ b/ppcls/data/postprocess/ramoutput.py @@ -0,0 +1,94 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np + +import paddle +import paddle.nn.functional as F + + +class RamOutPut(object): + def __init__(self, + language="cn", + tag_list="", + tag_list_chinese="", + threshold=0.68, + ram_class_threshold_path="", + delete_tag_index=[]): + """ + """ + self.language = language + assert tag_list, tag_list_chinese + self.tag_list = self.load_tag_list(tag_list) + self.delete_tag_index = delete_tag_index + self.tag_list_chinese = self.load_tag_list(tag_list_chinese) + self.num_class = len(self.tag_list) + self.class_threshold = paddle.ones([self.num_class]) * threshold + ram_class_threshold_path = ram_class_threshold_path + with open(ram_class_threshold_path, "r", encoding="utf-8") as f: + ram_class_threshold = [float(s.strip()) for s in f] + for key, value in enumerate(ram_class_threshold): + self.class_threshold[key] = value + + def load_tag_list(self, tag_list_file): + with open(tag_list_file, "r", encoding="utf-8") as f: + tag_list = f.read().splitlines() + tag_list = np.array(tag_list) + return tag_list + + def __call__(self, logits, file_names=None): + """ + logits is the result from model + bs is the batch size from model + file_names is useless but need in order to fit support framework of ppcls + """ + bs = len(file_names) + targets = paddle.where( + F.sigmoid(logits) > self.class_threshold, + paddle.to_tensor([1.0]), paddle.zeros(self.num_class)) + targets = targets.reshape([bs, -1]) + res = {} + tag = targets.cpu().numpy() + tag[:, self.delete_tag_index] = 0 + tag_output = [] + tag_output_chinese = [] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + tag_output.append(" | ".join(token)) + token_chinese = self.tag_list_chinese[index].squeeze(axis=1) + tag_output_chinese.append(" | ".join(token_chinese)) + res["cn"] = tag_output_chinese + res["en"] = tag_output + res["all"] = f"en : {tag_output}, cn: {tag_output_chinese}" + + scores = F.sigmoid(logits).numpy().reshape([bs, -1]) + class_ids_list = [] + scores_list = [] + + for b in range(bs): + index = np.argwhere(tag[b] == 1) + class_ids_list.append(index.tolist()) + scores_list.append(scores[b][index].tolist()) + + + + outputformat = { + "class_ids": class_ids_list, + "scores": scores_list, + "label_names": res[self.language] + } + + return outputformat diff --git a/ppcls/engine/engine_multimodal.py b/ppcls/engine/engine_multimodal.py new file mode 100644 index 0000000000..0b8a9197f9 --- /dev/null +++ b/ppcls/engine/engine_multimodal.py @@ -0,0 +1,145 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import platform +import paddle +import paddle.distributed as dist +from visualdl import LogWriter +from paddle import nn +import numpy as np +import random + +from ppcls.utils import logger +from ppcls.utils import save_predict_result + +from ppcls.data.utils.get_image_list import get_image_list +from ppcls.engine.train.utils import type_name +from ppcls.engine import engine +from ppcls.engine import train as train_method +from ppcls.engine.engine import ExportModel, load_dygraph_pretrain + + +class EngineMultimodal(engine.Engine): + def __init__(self, config, mode="train"): + super().__init__(config, mode) + self.train_epoch_func = train_method.train_epoch_multimodal + if mode == "train": + self.train_dataloader.collate_fn = self.train_dataloader.dataset.collect_fn_list + self.eval_dataloader.collate_fn = self.train_dataloader.dataset.collect_fn_list + + @paddle.no_grad() + def eval(self, epoch_id=0): + assert self.mode in ["train", "eval"] + self.model.eval() + eval_result = self.eval_func(self, epoch_id) + self.model.train() + return eval_result + + @paddle.no_grad() + def infer(self): + assert self.mode == "infer" and self.eval_mode == "classification" + results = [] + total_trainer = dist.get_world_size() + local_rank = dist.get_rank() + infer_imgs = self.config["Infer"]["infer_imgs"] + infer_list = self.config["Infer"].get("infer_list", None) + image_list = get_image_list(infer_imgs, infer_list=infer_list) + # data split + image_list = image_list[local_rank::total_trainer] + + batch_size = self.config["Infer"]["batch_size"] + self.model.eval() + batch_data = [] + image_file_list = [] + save_path = self.config["Infer"].get("save_dir", None) + for idx, image_file in enumerate(image_list): + with open(image_file, 'rb') as f: + x = f.read() + for process in self.preprocess_func: + x = process(x) + batch_data.append(x) + image_file_list.append(image_file) + if len(batch_data) >= batch_size or idx == len(image_list) - 1: + batch_tensor = paddle.to_tensor(batch_data) + + with self.auto_cast(is_eval=True): + tag_output = self.model.inference(batch_tensor) + + result = self.postprocess_func(*tag_output, image_file_list) + if not save_path: + logger.info(result) + results.extend(result) + batch_data.clear() + image_file_list.clear() + if save_path: + save_predict_result(save_path, results) + return results + + def export(self): + assert self.mode == "export" + use_multilabel = self.config["Global"].get("use_multilabel", False) + model = ExportModelMultiModal(self.config["Arch"], self.model, + use_multilabel) + if self.config["Global"]["pretrained_model"] is not None: + load_dygraph_pretrain(model.base_model, + self.config["Global"]["pretrained_model"]) + + model.eval() + + # for re-parameterization nets + for layer in self.model.sublayers(): + if hasattr(layer, "re_parameterize") and not getattr(layer, + "is_repped"): + layer.re_parameterize() + + save_path = os.path.join(self.config["Global"]["save_inference_dir"], + "inference") + + model = paddle.jit.to_static( + model, + input_spec=[ + paddle.static.InputSpec( + shape=[None] + self.config["Global"]["image_shape"], + dtype='float32') + ]) + if hasattr(model.base_model, + "quanter") and model.base_model.quanter is not None: + model.base_model.quanter.save_quantized_model(model, + save_path + "_int8") + else: + paddle.jit.save(model, save_path) + if self.config["Global"].get("export_for_fd", False): + src_path = self.config["Global"]["infer_config_path"] + dst_path = os.path.join( + self.config["Global"]["save_inference_dir"], 'inference.yml') + shutil.copy(src_path, dst_path) + logger.info( + f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"." + ) + + +class ExportModelMultiModal(ExportModel): + def __init__(self, config, model, use_multilabel): + super().__init__(config, model, use_multilabel) + self.CLIP_model = config.get("clip","") + def forward(self, x): + if self.CLIP_model == "image": + return self.base_model.encode_image(x) + elif self.CLIP_model == "text": + return self.base_model.encode_text(x) + else: + x = self.base_model.inference(x) + return x diff --git a/ppcls/engine/train/__init__.py b/ppcls/engine/train/__init__.py index 50bf9037f4..ddd35f0e5f 100644 --- a/ppcls/engine/train/__init__.py +++ b/ppcls/engine/train/__init__.py @@ -16,3 +16,4 @@ from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl from ppcls.engine.train.train_progressive import train_epoch_progressive from ppcls.engine.train.train_metabin import train_epoch_metabin +from ppcls.engine.train.train_multimodal import train_epoch_multimodal diff --git a/ppcls/engine/train/train_multimodal.py b/ppcls/engine/train/train_multimodal.py new file mode 100644 index 0000000000..7447c5dbe7 --- /dev/null +++ b/ppcls/engine/train/train_multimodal.py @@ -0,0 +1,95 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import paddle +from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name +from ppcls.utils import profiler + + +def train_epoch_multimodal(engine, epoch_id, print_batch_step): + tic = time.time() + + if not hasattr(engine, "train_dataloader_iter"): + engine.train_dataloader_iter = iter(engine.train_dataloader) + + for iter_id in range(engine.iter_per_epoch): + # fetch data batch from dataloader + try: + batch = next(engine.train_dataloader_iter) + except Exception: + # NOTE: reset DALI dataloader manually + if engine.use_dali: + engine.train_dataloader.reset() + engine.train_dataloader_iter = iter(engine.train_dataloader) + batch = next(engine.train_dataloader_iter) + assert isinstance(batch, tuple) or isinstance(batch, list) + profiler.add_profiler_step(engine.config["profiler_options"]) + if iter_id == 5: + for key in engine.time_info: + engine.time_info[key].reset() + engine.time_info["reader_cost"].update(time.time() - tic) + + batch_size = batch[0][0].shape[0] + + engine.global_step += 1 + + # image input + with engine.auto_cast(is_eval=False): + out = forward(engine, batch) + loss_dict = engine.train_loss_func(out, batch[1]) + + # loss + loss = loss_dict["loss"] / engine.update_freq + + # backward & step opt + scaled = engine.scaler.scale(loss) + scaled.backward() + if (iter_id + 1) % engine.update_freq == 0: + for i in range(len(engine.optimizer)): + # optimizer.step() with auto amp + engine.scaler.step(engine.optimizer[i]) + engine.scaler.update() + + if (iter_id + 1) % engine.update_freq == 0: + # clear grad + for i in range(len(engine.optimizer)): + engine.optimizer[i].clear_grad() + # step lr(by step) + for i in range(len(engine.lr_sch)): + if not getattr(engine.lr_sch[i], "by_epoch", False): + engine.lr_sch[i].step() + # update ema + if engine.ema: + engine.model_ema.update(engine.model) + + # below code just for logging + # update metric_for_logger + update_metric(engine, out, batch, batch_size) + # update_loss_for_logger + update_loss(engine, loss_dict, batch_size) + engine.time_info["batch_cost"].update(time.time() - tic) + if iter_id % print_batch_step == 0: + log_info(engine, batch_size, epoch_id, iter_id) + tic = time.time() + + # step lr(by epoch) + for i in range(len(engine.lr_sch)): + if getattr(engine.lr_sch[i], "by_epoch", False) and \ + type_name(engine.lr_sch[i]) != "ReduceOnPlateau": + engine.lr_sch[i].step() + + +def forward(engine, batch): + return engine.model(*batch) diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index 7ab8be4fab..ee9dfbcc00 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -21,6 +21,7 @@ from .pairwisecosface import PairwiseCosface from .dmlloss import DMLLoss from .distanceloss import DistanceLoss +from .ramloss import RAMLoss from .softtargetceloss import SoftTargetCrossEntropy from .distillationloss import DistillationCELoss from .distillationloss import DistillationGTCELoss @@ -70,7 +71,7 @@ def __call__(self, input, batch): loss_dict = {} # just for accelerate classification traing speed if len(self.loss_func) == 1: - loss = self.loss_func[0](input, batch) + loss = self.loss_func[0](*input, batch) loss_dict.update(loss) loss_dict["loss"] = list(loss.values())[0] else: diff --git a/ppcls/loss/contrasiveloss.py b/ppcls/loss/contrasiveloss.py index d27dbe22e3..e9fe03a629 100644 --- a/ppcls/loss/contrasiveloss.py +++ b/ppcls/loss/contrasiveloss.py @@ -20,6 +20,7 @@ import paddle import paddle.nn as nn +from paddle.nn import functional as F from ppcls.loss.xbm import CrossBatchMemory @@ -39,6 +40,7 @@ def __init__(self, embedding_size: int, normalize_feature=True, epsilon: float=1e-5, + is_text_image_pairs=False, feature_from: str="features"): super(ContrastiveLoss, self).__init__() self.margin = margin @@ -46,9 +48,22 @@ def __init__(self, self.normalize_feature = normalize_feature self.epsilon = epsilon self.feature_from = feature_from + self.is_text_image_pairs = is_text_image_pairs + + def text_image_pairs_constrative_loss(self, logits_per_image, + logits_per_text, labels): + total_loss = (F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels)) / 2 + + return {"Contrastive_loss": total_loss} def forward(self, input: Dict[str, paddle.Tensor], target: paddle.Tensor) -> Dict[str, paddle.Tensor]: + + if self.is_text_image_pairs: + return self.text_image_pairs_constrative_loss( + input["image"], input["text"], target) + feats = input[self.feature_from] labels = target diff --git a/ppcls/loss/ramloss.py b/ppcls/loss/ramloss.py new file mode 100644 index 0000000000..9e63fcc016 --- /dev/null +++ b/ppcls/loss/ramloss.py @@ -0,0 +1,27 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.nn as nn + + +# loss for RAM and RAM++ +class RAMLoss(nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, loss_tag, loss_dis, loss_alignment, *kwargs): + + ## **kwargs is useless but to be compatiable with the framework + loss = loss_tag + loss_dis + loss_alignment + return {"RAMLoss": loss} diff --git a/ppcls/utils/ram/bpe_simple_vocab_16e6.txt.gz b/ppcls/utils/ram/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000..7b5088a527 Binary files /dev/null and b/ppcls/utils/ram/bpe_simple_vocab_16e6.txt.gz differ diff --git a/ppcls/utils/ram/ram_tag_list.txt b/ppcls/utils/ram/ram_tag_list.txt new file mode 100644 index 0000000000..49c840b719 --- /dev/null +++ b/ppcls/utils/ram/ram_tag_list.txt @@ -0,0 +1,4585 @@ +3D CG rendering +3D glasses +abacus +abalone +monastery +belly +academy +accessory +accident +accordion +acorn +acrylic paint +act +action +action film +activity +actor +adaptation +add +adhesive tape +adjust +adult +adventure +advertisement +antenna +aerobics +spray can +afro +agriculture +aid +air conditioner +air conditioning +air sock +aircraft cabin +aircraft model +air field +air line +airliner +airman +plane +airplane window +airport +airport runway +airport terminal +airship +airshow +aisle +alarm +alarm clock +mollymawk +album +album cover +alcohol +alcove +algae +alley +almond +aloe vera +alp +alpaca +alphabet +german shepherd +altar +amber +ambulance +bald eagle +American shorthair +amethyst +amphitheater +amplifier +amusement park +amusement ride +anchor +ancient +anemone +angel +angle +animal +animal sculpture +animal shelter +animation +animation film +animator +anime +ankle +anklet +anniversary +trench coat +ant +antelope +antique +antler +anvil +apartment +ape +app +app icon +appear +appearance +appetizer +applause +apple +apple juice +apple pie +apple tree +applesauce +appliance +appointment +approach +apricot +apron +aqua +aquarium +aquarium fish +aqueduct +arcade +arcade machine +arch +arch bridge +archaelogical excavation +archery +archipelago +architect +architecture +archive +archway +area +arena +argument +arm +armadillo +armband +armchair +armoire +armor +army +army base +army tank +array +arrest +arrow +art +art exhibition +art gallery +art print +art school +art studio +art vector illustration +artichoke +article +artifact +artist +artists loft +ash +ashtray +asia temple +asparagus +asphalt road +assemble +assembly +assembly line +association +astronaut +astronomer +athlete +athletic +atlas +atm +atmosphere +atrium +attach +fighter jet +attend +attraction +atv +eggplant +auction +audi +audio +auditorium +aurora +author +auto factory +auto mechanic +auto part +auto show +auto showroom +car battery +automobile make +automobile model +motor vehicle +autumn +autumn forest +autumn leave +autumn park +autumn tree +avatar +avenue +aviator sunglasses +avocado +award +award ceremony +award winner +shed +ax +azalea +baboon +baby +baby bottle +baby carriage +baby clothe +baby elephant +baby food +baby seat +baby shower +back +backdrop +backlight +backpack +backyard +bacon +badge +badger +badlands +badminton +badminton racket +bag +bagel +bagpipe +baguette +bait +baked goods +baker +bakery +baking +baking sheet +balance +balance car +balcony +ball +ball pit +ballerina +ballet +ballet dancer +ballet skirt +balloon +balloon arch +baseball player +ballroom +bamboo +bamboo forest +banana +banana bread +banana leaf +banana tree +band +band aid +bandage +headscarf +bandeau +bangs +bracelet +balustrade +banjo +bank +bank card +bank vault +banknote +banner +banquet +banquet hall +banyan tree +baozi +baptism +bar +bar code +bar stool +barbecue +barbecue grill +barbell +barber +barber shop +barbie +barge +barista +bark +barley +barn +barn owl +barn door +barrel +barricade +barrier +handcart +bartender +baseball +baseball base +baseball bat +baseball hat +baseball stadium +baseball game +baseball glove +baseball pitcher +baseball team +baseball uniform +basement +basil +basin +basket +basket container +basketball +basketball backboard +basketball coach +basketball court +basketball game +basketball hoop +basketball player +basketball stadium +basketball team +bass +bass guitar +bass horn +bassist +bat +bath +bath heater +bath mat +bath towel +swimwear +bathrobe +bathroom +bathroom accessory +bathroom cabinet +bathroom door +bathroom mirror +bathroom sink +toilet paper +bathroom window +batman +wand +batter +battery +battle +battle rope +battleship +bay +bay bridge +bay window +bayberry +bazaar +beach +beach ball +beach chair +beach house +beach hut +beach towel +beach volleyball +lighthouse +bead +beagle +beak +beaker +beam +bean +bean bag chair +beanbag +bear +bear cub +beard +beast +beat +beautiful +beauty +beauty salon +beaver +bed +bedcover +bed frame +bedroom +bedding +bedpan +bedroom window +bedside lamp +bee +beech tree +beef +beekeeper +beeper +beer +beer bottle +beer can +beer garden +beer glass +beer hall +beet +beetle +beige +clock +bell pepper +bell tower +belt +belt buckle +bench +bend +bengal tiger +bento +beret +berry +berth +beverage +bib +bibimbap +bible +bichon +bicycle +bicycle helmet +bicycle wheel +biker +bidet +big ben +bike lane +bike path +bike racing +bike ride +bikini +bikini top +bill +billard +billboard +billiard table +bin +binder +binocular +biology laboratory +biplane +birch +birch tree +bird +bird bath +bird feeder +bird house +bird nest +birdbath +bird cage +birth +birthday +birthday cake +birthday candle +birthday card +birthday party +biscuit +bishop +bison +bit +bite +black +black sheep +blackberry +blackbird +blackboard +blacksmith +blade +blanket +sports coat +bleacher +blender +blessing +blind +eye mask +flasher +snowstorm +block +blog +blood +bloom +blossom +blouse +blow +hair drier +blowfish +blue +blue artist +blue jay +blue sky +blueberry +bluebird +pig +board +board eraser +board game +boardwalk +boat +boat deck +boat house +paddle +boat ride +bobfloat +bobcat +body +bodyboard +bodybuilder +boiled egg +boiler +bolo tie +bolt +bomb +bomber +bonasa umbellu +bone +bonfire +bonnet +bonsai +book +book cover +bookcase +folder +bookmark +bookshelf +bookstore +boom microphone +boost +boot +border +Border collie +botanical garden +bottle +bottle cap +bottle opener +bottle screw +bougainvillea +boulder +bouquet +boutique +boutique hotel +bow +bow tie +bow window +bowl +bowling +bowling alley +bowling ball +bowling equipment +box +box girder bridge +box turtle +boxer +underdrawers +boxing +boxing glove +boxing ring +boy +brace +bracket +braid +brain +brake +brake light +branch +brand +brandy +brass +brass plaque +bread +breadbox +break +breakfast +seawall +chest +brewery +brick +brick building +wall +brickwork +wedding dress +bride +groom +bridesmaid +bridge +bridle +briefcase +bright +brim +broach +broadcasting +broccoli +bronze +bronze medal +bronze sculpture +bronze statue +brooch +creek +broom +broth +brown +brown bear +brownie +brunch +brunette +brush +coyote +brussels sprout +bubble +bubble gum +bubble tea +bucket cabinet +shield +bud +buddha +buffalo +buffet +bug +build +builder +building +building block +building facade +building material +lamp +bull +bulldog +bullet +bullet train +bulletin board +bulletproof vest +bullfighting +megaphone +bullring +bumblebee +bumper +roll +bundle +bungee +bunk bed +bunker +bunny +buoy +bureau +burial chamber +burn +burrito +bus +bus driver +bus interior +bus station +bus stop +bus window +bush +business +business card +business executive +business suit +business team +business woman +businessman +bust +butcher +butchers shop +butte +butter +cream +butterfly +butterfly house +button +buttonwood +buy +taxi +cabana +cabbage +cabin +cabin car +cabinet +cabinetry +cable +cable car +cactus +cafe +canteen +cage +cake +cake stand +calculator +caldron +calendar +calf +call +phone box +calligraphy +calm +camcorder +camel +camera +camera lens +camouflage +camp +camper +campfire +camping +campsite +campus +can +can opener +canal +canary +cancer +candle +candle holder +candy +candy bar +candy cane +candy store +cane +jar +cannon +canopy +canopy bed +cantaloupe +cantilever bridge +canvas +canyon +cap +cape +cape cod +cappuccino +capsule +captain +capture +car +car dealership +car door +car interior +car logo +car mirror +parking lot +car seat +car show +car wash +car window +caramel +card +card game +cardboard +cardboard box +cardigan +cardinal +cargo +cargo aircraft +cargo ship +caribbean +carnation +carnival +carnivore +carousel +carp +carpenter +carpet +slipper +house finch +coach +dalmatian +aircraft carrier +carrot +carrot cake +carry +cart +carton +cartoon +cartoon character +cartoon illustration +cartoon style +carve +case +cash +cashew +casino +casserole +cassette +cassette deck +plaster bandage +casting +castle +cat +cat bed +cat food +cat furniture +cat tree +catacomb +catamaran +catamount +catch +catcher +caterpillar +catfish +cathedral +cattle +catwalk +catwalk show +cauliflower +cave +caviar +CD +CD player +cedar +ceiling +ceiling fan +celebrate +celebration +celebrity +celery +cello +smartphone +cement +graveyard +centerpiece +centipede +ceramic +ceramic tile +cereal +ceremony +certificate +chain +chain saw +chair +chairlift +daybed +chalet +chalice +chalk +chamber +chameleon +champagne +champagne flute +champion +championship +chandelier +changing table +channel +chap +chapel +character sculpture +charcoal +charge +charger +chariot +charity +charity event +charm +graph +chase +chassis +check +checkbook +chessboard +checklist +cheer +cheerlead +cheese +cheeseburger +cheesecake +cheetah +chef +chemical compound +chemist +chemistry +chemistry lab +cheongsam +cherry +cherry blossom +cherry tomato +cherry tree +chess +chestnut +chicken +chicken breast +chicken coop +chicken salad +chicken wing +garbanzo +chiffonier +chihuahua +child +child actor +childs room +chile +chili dog +chimney +chimpanzee +chinaware +chinese cabbage +chinese garden +chinese knot +chinese rose +chinese tower +chip +chipmunk +chisel +chocolate +chocolate bar +chocolate cake +chocolate chip +chocolate chip cookie +chocolate milk +chocolate mousse +truffle +choir +kitchen knife +cutting board +chopstick +christmas +christmas ball +christmas card +christmas decoration +christmas dinner +christmas eve +christmas hat +christmas light +christmas market +christmas ornament +christmas tree +chrysanthemum +church +church tower +cider +cigar +cigar box +cigarette +cigarette case +waistband +cinema +photographer +cinnamon +circle +circuit +circuit board +circus +water tank +citrus fruit +city +city bus +city hall +city nightview +city park +city skyline +city square +city street +city wall +city view +clam +clarinet +clasp +class +classic +classroom +clavicle +claw +clay +pottery +clean +clean room +cleaner +cleaning product +clear +cleat +clementine +client +cliff +climb +climb mountain +climber +clinic +clip +clip art +clipboard +clipper +clivia +cloak +clogs +close-up +closet +cloth +clothe +clothing +clothespin +clothesline +clothing store +cloud +cloud forest +cloudy +clover +joker +clown fish +club +clutch +clutch bag +coal +coast +coat +coatrack +cob +cock +cockatoo +cocker +cockpit +roach +cocktail +cocktail dress +cocktail shaker +cocktail table +cocoa +coconut +coconut tree +coffee +coffee bean +coffee cup +coffee machine +coffee shop +coffeepot +coffin +cognac +spiral +coin +coke +colander +cold +slaw +collaboration +collage +collection +college student +sheepdog +crash +color +coloring book +coloring material +pony +pillar +comb +combination lock +comic +comedy +comedy film +comet +comfort +comfort food +comic book +comic book character +comic strip +commander +commentator +community +commuter +company +compass +compete +contest +competitor +composer +composition +compost +computer +computer box +computer chair +computer desk +keyboard +computer monitor +computer room +computer screen +computer tower +concept car +concert +concert hall +conch +concrete +condiment +condom +condominium +conductor +cone +meeting +conference center +conference hall +meeting room +confetti +conflict +confluence +connect +connector +conservatory +constellation +construction site +construction worker +contain +container +container ship +continent +profile +contract +control +control tower +convenience store +convention +conversation +converter +convertible +transporter +cook +cooking +cooking spray +cooker +cool +cooler +copper +copy +coral +coral reef +rope +corded phone +liquor +corgi +cork +corkboard +cormorant +corn +corn field +cornbread +corner +trumpet +cornice +cornmeal +corral +corridor +corset +cosmetic +cosmetics brush +cosmetics mirror +cosplay +costume +costumer film designer +infant bed +cottage +cotton +cotton candy +couch +countdown +counter +counter top +country artist +country house +country lane +country pop artist +countryside +coupe +couple +couple photo +courgette +course +court +courthouse +courtyard +cousin +coverall +cow +cowbell +cowboy +cowboy boot +cowboy hat +crab +crabmeat +crack +cradle +craft +craftsman +cranberry +crane +crape +crapper +crate +crater lake +lobster +crayon +cream cheese +cream pitcher +create +creature +credit card +crescent +croissant +crest +crew +cricket +cricket ball +cricket team +cricketer +crochet +crock pot +crocodile +crop +crop top +cross +crossbar +crossroad +crosstalk +crosswalk +crouton +crow +crowbar +crowd +crowded +crown +crt screen +crucifix +cruise +cruise ship +cruiser +crumb +crush +crutch +crystal +cub +cube +cucumber +cue +cuff +cufflink +cuisine +farmland +cup +cupcake +cupid +curb +curl +hair roller +currant +currency +curry +curtain +curve +pad +customer +cut +cutlery +cycle +cycling +cyclone +cylinder +cymbal +cypress +cypress tree +dachshund +daffodil +dagger +dahlia +daikon +dairy +daisy +dam +damage +damp +dance +dance floor +dance room +dancer +dandelion +dark +darkness +dart +dartboard +dashboard +date +daughter +dawn +day bed +daylight +deadbolt +death +debate +debris +decanter +deck +decker bus +decor +decorate +decorative picture +deer +defender +deity +delicatessen +deliver +demolition +monster +demonstration +den +denim jacket +dentist +department store +depression +derby +dermopathy +desert +desert road +design +designer +table +table lamp +desktop +desktop computer +dessert +destruction +detective +detergent +dew +dial +diamond +diaper +diaper bag +journal +die +diet +excavator +number +digital clock +dill +dinner +rowboat +dining room +dinner party +dinning table +dinosaur +dip +diploma +direct +director +dirt +dirt bike +dirt field +dirt road +dirt track +disaster +disciple +disco +disco ball +discotheque +disease +plate +dish antenna +dish washer +dishrag +dishes +dishsoap +Disneyland +dispenser +display +display window +trench +dive +diver +diving board +paper cup +dj +doberman +dock +doctor +document +documentary +dog +dog bed +dog breed +dog collar +dog food +dog house +doll +dollar +dollhouse +dolly +dolphin +dome +domicile +domino +donkey +donut +doodle +door +door handle +doormat +doorplate +doorway +dormitory +dough +downtown +dozer +drag +dragon +dragonfly +drain +drama +drama film +draw +drawer +drawing +drawing pin +pigtail +dress +dress hat +dress shirt +dress shoe +dress suit +dresser +dressing room +dribble +drift +driftwood +drill +drink +drinking water +drive +driver +driveway +drone +drop +droplight +dropper +drought +medicine +pharmacy +drum +drummer +drumstick +dry +duchess +duck +duckbill +duckling +duct tape +dude +duet +duffel +canoe +dumbbell +dumpling +dune +dunk +durian +dusk +dust +garbage truck +dustpan +duvet +DVD +dye +eagle +ear +earmuff +earphone +earplug +earring +earthquake +easel +easter +easter bunny +easter egg +eat +restaurant +eclair +eclipse +ecosystem +edit +education +educator +eel +egg +egg roll +egg tart +eggbeater +egret +Eiffel tower +elastic band +senior +electric chair +electric drill +electrician +electricity +electron +electronic +elephant +elevation map +elevator +elevator car +elevator door +elevator lobby +elevator shaft +embankment +embassy +embellishment +ember +emblem +embroidery +emerald +emergency +emergency service +emergency vehicle +emotion +Empire State Building +enamel +enclosure +side table +energy +engagement +engagement ring +engine +engine room +engineer +engineering +english shorthair +ensemble +enter +entertainer +entertainment +entertainment center +entrance +entrance hall +envelope +equestrian +equipment +eraser +erhu +erosion +escalator +escargot +espresso +estate +estuary +eucalyptus tree +evening +evening dress +evening light +evening sky +evening sun +event +evergreen +ewe +excavation +exercise +exhaust hood +exhibition +exit +explorer +explosion +extension cord +extinguisher +extractor +extrude +eye +eye shadow +eyebrow +eyeliner +fabric +fabric store +facade +face +face close-up +face powder +face towel +facial tissue holder +facility +factory +factory workshop +fair +fairground +fairy +falcon +fall +family +family car +family photo +family room +fan +fang +farm +farmer +farmer market +farmhouse +fashion +fashion accessory +fashion designer +fashion girl +fashion illustration +fashion look +fashion model +fashion show +fast food +fastfood restaurant +father +faucet +fault +fauna +fawn +fax +feast +feather +fedora +feed +feedbag +feeding +feeding chair +feline +mountain lion +fence +fender +fern +ferret +ferris wheel +ferry +fertilizer +festival +fiber +fiction +fiction book +field +field road +fig +fight +figure skater +figurine +file +file photo +file cabinet +fill +film camera +film director +film format +film premiere +film producer +filming +filter +fin +hand +finish line +fir +fir tree +fire +fire alarm +fire department +fire truck +fire escape +fire hose +fire pit +fire station +firecracker +fireman +fireplace +firework +firework display +first-aid kit +fish +fish boat +fish market +fish pond +fishbowl +fisherman +fishing +fishing boat +fishing net +fishing pole +fishing village +fitness +fitness course +five +fixture +fjord +flag +flag pole +flake +flame +flamingo +flannel +flap +flare +flash +flask +flat +flatfish +flavor +flea +flea market +fleet +flight +flight attendant +flip +flip-flop +flipchart +float +flock +flood +floor +floor fan +floor mat +floor plan +floor window +floral arrangement +florist +floss +flour +flow +flower +flower basket +flower bed +flower box +flower field +flower girl +flower market +fluid +flush +flute +fly +fly fishing +flyer +horse +foam +fog +foggy +foie gra +foil +folding chair +leaf +folk artist +folk dance +folk rock artist +fondant +hotpot +font +food +food coloring +food court +food processor +food stand +food truck +foosball +foot +foot bridge +football +football coach +football college game +football match +football field +football game +football helmet +football player +football stadium +football team +path +footprint +footrest +footstall +footwear +forbidden city +ford +forehead +forest +forest fire +forest floor +forest path +forest road +forge +fork +forklift +form +formal garden +formation +formula 1 +fort +fortification +forward +fossil +foundation +fountain +fountain pen +fox +frame +freckle +highway +lorry +French +French bulldog +French fries +French toast +freshener +fridge +fried chicken +fried egg +fried rice +friendship +frisbee +frog +frost +frosting +frosty +frozen +fruit +fruit cake +fruit dish +fruit market +fruit salad +fruit stand +fruit tree +fruits shop +fry +frying pan +fudge +fuel +fume hood +fun +funeral +fungi +funnel +fur +fur coat +furniture +futon +gadget +muzzle +galaxy +gallery +game +game board +game controller +ham +gang +garage +garage door +garage kit +garbage +garden +garden asparagus +garden hose +garden spider +gardener +gardening +garfield +gargoyle +wreath +garlic +garment +gas +gas station +gas stove +gasmask +collect +gathering +gauge +gazebo +gear +gecko +geisha +gel +general store +generator +geranium +ghost +gift +gift bag +gift basket +gift box +gift card +gift shop +gift wrap +gig +gin +ginger +gingerbread +gingerbread house +ginkgo tree +giraffe +girl +give +glacier +gladiator +glass bead +glass bottle +glass bowl +glass box +glass building +glass door +glass floor +glass house +glass jar +glass plate +glass table +glass vase +glass wall +glass window +glasses +glaze +glider +earth +glove +glow +glue pudding +go +go for +goal +goalkeeper +goat +goat cheese +gobi +goggles +gold +gold medal +Golden Gate Bridge +golden retriever +goldfish +golf +golf cap +golf cart +golf club +golf course +golfer +goose +gorilla +gothic +gourd +government +government agency +gown +graduate +graduation +grain +grampus +grand prix +grandfather +grandmother +grandparent +granite +granola +grape +grapefruit +wine +grass +grasshopper +grassland +grassy +grater +grave +gravel +gravestone +gravy +gravy boat +gray +graze +grazing +green +greenery +greet +greeting +greeting card +greyhound +grid +griddle +grill +grille +grilled eel +grind +grinder +grits +grocery bag +grotto +ground squirrel +group +group photo +grove +grow +guacamole +guard +guard dog +guest house +guest room +guide +guinea pig +guitar +guitarist +gulf +gull +gun +gundam +gurdwara +guzheng +gym +gymnast +habitat +hacker +hail +hair +hair color +hair spray +hairbrush +haircut +hairgrip +hairnet +hairpin +hairstyle +half +hall +halloween +halloween costume +halloween pumpkin +halter top +hamburg +hamburger +hami melon +hammer +hammock +hamper +hamster +hand dryer +hand glass +hand towel +handbag +handball +handcuff +handgun +handkerchief +handle +handsaw +handshake +handstand +handwriting +hanfu +hang +hangar +hanger +happiness +harbor +harbor seal +hard rock artist +hardback book +safety helmet +hardware +hardware store +hardwood +hardwood floor +mouth organ +pipe organ +harpsichord +harvest +harvester +hassock +hat +hatbox +hautboy +hawthorn +hay +hayfield +hazelnut +head +head coach +headlight +headboard +headdress +headland +headquarter +hearing +heart +heart shape +heat +heater +heather +hedge +hedgehog +heel +helicopter +heliport +helmet +help +hen +henna +herb +herd +hermit crab +hero +heron +hibiscus +hibiscus flower +hide +high bar +high heel +highland +highlight +hike +hiker +hiking boot +hiking equipment +hill +hill country +hill station +hillside +hindu temple +hinge +hip +hip hop artist +hippo +historian +historic +history +hockey +hockey arena +hockey game +hockey player +hockey stick +hoe +hole +vacation +holly +holothurian +home +home appliance +home base +home decor +home interior +home office +home theater +homework +hummus +honey +beehive +honeymoon +hood +hoodie +hook +jump +horizon +hornbill +horned cow +hornet +horror +horror film +horse blanket +horse cart +horse farm +horse ride +horseback +horseshoe +hose +hospital +hospital bed +hospital room +host +inn +hot +hot air balloon +hot dog +hot sauce +hot spring +hotel +hotel lobby +hotel room +hotplate +hourglass +house +house exterior +houseplant +hoverboard +howler +huddle +hug +hula hoop +person +humidifier +hummingbird +humpback whale +hunt +hunting lodge +hurdle +hurricane +husky +hut +hyaena +hybrid +hydrangea +hydrant +seaplane +ice +ice bag +polar bear +ice cave +icecream +ice cream cone +ice cream parlor +ice cube +ice floe +ice hockey player +ice hockey team +lollipop +ice maker +rink +ice sculpture +ice shelf +skate +ice skating +iceberg +icicle +icing +icon +id photo +identity card +igloo +light +iguana +illuminate +illustration +image +impala +incense +independence day +individual +indoor +indoor rower +induction cooker +industrial area +industry +infantry +inflatable boat +information desk +infrastructure +ingredient +inhalator +injection +injury +ink +inking pad +inlet +inscription +insect +install +instrument +insulated cup +interaction +interior design +website +intersection +interview +invertebrate +invitation +ipad +iphone +ipod +iris +iron +ironing board +irrigation system +island +islet +isopod +ivory +ivy +izakaya +jack +jackcrab +jacket +jacuzzi +jade +jaguar +jail cell +jam +japanese garden +jasmine +jaw +jay +jazz +jazz artist +jazz fusion artist +jeans +jeep +jelly +jelly bean +jellyfish +jet +motorboat +jewel +jewellery +jewelry shop +jigsaw puzzle +rickshaw +jockey +jockey cap +jog +joint +journalist +joystick +judge +jug +juggle +juice +juicer +jujube +jump rope +jumpsuit +jungle +junkyard +kale +kaleidoscope +kangaroo +karaoke +karate +karting +kasbah +kayak +kebab +key +keycard +khaki +kick +kilt +kimono +kindergarden classroom +kindergarten +king +king crab +kiss +kit +kitchen +kitchen cabinet +kitchen counter +kitchen floor +kitchen hood +kitchen island +kitchen sink +kitchen table +kitchen utensil +kitchen window +kitchenware +kite +kiwi +knee pad +kneel +knife +rider +knit +knitting needle +knob +knocker +knot +koala +koi +ktv +laboratory +lab coat +label +labrador +maze +lace +lace dress +ladder +ladle +ladybird +lagoon +lake +lake district +lake house +lakeshore +lamb +lamb chop +lamp post +lamp shade +spear +land +land vehicle +landfill +landing +landing deck +landmark +landscape +landslide +lanyard +lantern +lap +laptop +laptop keyboard +larva +lasagne +laser +lash +lasso +latch +latex +latte +laugh +launch +launch event +launch party +laundromat +laundry +laundry basket +laundry room +lava +lavender +lawn +lawn wedding +lawyer +lay +lead +lead singer +lead to +leader +leak +lean +learn +leash +leather +leather jacket +leather shoe +speech +lecture hall +lecture room +ledge +leftover +leg +legend +legging +legislative chamber +lego +legume +lemon +lemon juice +lemonade +lemur +lens +lens flare +lentil +leopard +leotard +tights +leprechaun +lesson +letter +mailbox +letter logo +lettering +lettuce +level +library +license +license plate +lichen +lick +lid +lie +life belt +life jacket +lifeboat +lifeguard +lift +light fixture +light show +light switch +lighting +lightning +lightning rod +lilac +lily +limb +lime +limestone +limo +line +line art +line up +linen +liner +lion +lip balm +lipstick +liquid +liquor store +list +litchi +live +livestock +living room +living space +lizard +load +loading dock +loafer +hallway +locate +lock +lock chamber +locker +loft +log +log cabin +logo +loki +long hair +longboard +loom +loop +lose +lottery +lotus +love +loveseat +luggage +lumber +lumberjack +lunch +lunch box +lush +luxury +luxury yacht +mac +macadamia +macaque +macaroni +macaw +machete +machine +machine gun +magazine +magic +magician +magnet +magnifying glass +magnolia +magpie +mahjong +mahout +maid +chain mail +mail slot +make +makeover +makeup artist +makeup tool +mallard +mallard duck +mallet +mammal +mammoth +man +management +manager +manatee +mandala +mandarin orange +mandarine +mane +manga +manger +mango +mangosteen +mangrove +manhattan +manhole +manhole cover +manicure +mannequin +manor house +mansion +mantid +mantle +manufactured home +manufacturing +manuscript +map +maple +maple leaf +maple syrup +maraca +marathon +marble +march +marching band +mare +marigold +marine +marine invertebrate +marine mammal +puppet +mark +market +market square +market stall +marriage +martial +martial artist +martial arts gym +martini +martini glass +mascara +mascot +mashed potato +masher +mask +massage +mast +mat +matador +match +matchbox +material +mattress +mausoleum +maxi dress +meal +measuring cup +measuring tape +meat +meatball +mechanic +mechanical fan +medal +media +medical equipment +medical image +medical staff +medicine cabinet +medieval +medina +meditation +meerkat +meet +melon +monument +menu +mermaid +net +mess +messenger bag +metal +metal artist +metal detector +meter +mezzanine +microphone +microscope +microwave +midnight +milestone +military uniform +milk +milk can +milk tea +milkshake +mill +mine +miner +mineral +mineral water +miniskirt +miniature +minibus +minister +minivan +mint +mint candy +mirror +miss +missile +mission +mistletoe +mix +mixer +mixing bowl +mixture +moat +mobility scooter +model +model car +modern +modern tower +moisture +mold +molding +mole +monarch +money +monitor +monk +monkey +monkey wrench +monochrome +monocycle +monster truck +moon +moon cake +moonlight +moor +moose +swab +moped +morning +morning fog +morning light +morning sun +mortar +mosaic +mosque +mosquito +moss +motel +moth +mother +motherboard +motif +sport +motor +motorbike +motorcycle +motorcycle helmet +motorcycle racer +motorcyclist +motorsport +mound +mountain +mountain bike +mountain biker +mountain biking +mountain gorilla +mountain lake +mountain landscape +mountain pass +mountain path +mountain range +mountain river +mountain snowy +mountain stream +mountain view +mountain village +mountaineer +mountaineering bag +mouse +mousepad +mousetrap +mouth +mouthwash +move +movie poster +movie ticket +mower +mp3 player +mr +mud +muffin +mug +mulberry +mulch +mule +municipality +mural +muscle +muscle car +museum +mushroom +music +music festival +music stool +music studio +music video performer +musical keyboard +musician +mussel +mustard +mythology +nacho +nail polish +nailfile +nanny +napkin +narrow +national flag +nativity scene +natural history museum +nature +nature reserve +navigation +navratri +navy +nebula +neck +neckband +necklace +neckline +nectar +nectarine +needle +neighbor +neighbourhood +neon +neon light +nerve +nest +new year +newborn +newfoundland +newlywed +news +news conference +newsstand +night +night market +night sky +night view +nightclub +nightstand +noodle +nose +noseband +note +notebook +notepad +notepaper +notice +number icon +nun +nurse +nursery +nursing home +nut +nutcracker +oak +oak tree +oar +oasis +oast house +oatmeal +oats +obelisk +observation tower +observatory +obstacle course +sea +octopus +offer +office +office building +office chair +office cubicle +office desk +office supply +office window +officer +official +oil +oil lamp +oil painting +oilrig +okra +old photo +olive +olive oil +olive tree +omelet +onion +onion ring +opal +open +opening +opening ceremony +opera +opera house +operate +operating room +operation +optical shop +orangutan +orange +orange juice +orange tree +orangery +orbit +orchard +orchestra pit +orchid +order +organization +origami +ornament +osprey +ostrich +otter +out +outcrop +outdoor +outhouse +electric outlet +outline +oval +oven +overall +overcoat +overpass +owl +oyster +teething ring +pack +package +paddock +police van +padlock +paella +pagoda +pain +paint brush +painter +paisley bandanna +palace +palette +paling +pall +palm tree +pan +pancake +panda +panel +panorama +pansy +pant +pantry +pants +pantyhose +papaya +paper +paper bag +paper cutter +paper lantern +paper plate +paper towel +paperback book +paperweight +parachute +parade +paradise +parrot +paramedic +paraquet +parasail +paratrooper +parchment +parish +park +park bench +parking +parking garage +parking meter +parking sign +parliament +parsley +participant +partner +partridge +party +party hat +pass +passage +passbook +passenger +passenger ship +passenger train +passion fruit +passport +pasta +paste +pastry +pasture +patch +patient +pattern +pavement +pavilion +paw +pay +payphone +pea +peace +peach +peacock +peak +peanut +peanut butter +pear +pearl +pebble +pecan +pedestrian +pedestrian bridge +pedestrian street +peel +peeler +pegboard +pegleg +pelican +pen +penalty kick +pencil +pencil case +pencil sharpener +pencil skirt +pendant +pendulum +penguin +peninsula +pennant +penny +piggy bank +peony +pepper +pepper grinder +peppercorn +pepperoni +perch +perform +performance +performance arena +perfume +pergola +persian cat +persimmon +personal care +personal flotation device +pest +pet +pet shop +pet store +petal +petunia +church bench +pheasant +phenomenon +philosopher +phone +phonebook +record player +photo +photo booth +photo frame +photography +physicist +physics laboratory +pianist +piano +plectrum +pick up +pickle +picnic +picnic area +picnic basket +picnic table +picture +picture frame +pie +pigeon +pilgrim +tablet +pillow +pilot +pilot boat +pin +pine +pine cone +pine forest +pine nut +pineapple +table tennis table +table tennis +pink +pint +pipa +pipe +pipe bowl +pirate +pirate flag +pirate ship +pistachio +ski slope +pocket bread +pitaya +pitbull +pitch +pitcher +pitcher plant +pitchfork +pizza +pizza cutter +pizza pan +pizzeria +placard +place +place mat +plaid +plain +plan +planet +planet earth +plank +plant +plantation +planting +plaque +plaster +plastic +plasticine +plateau +platform +platinum +platter +play +play badminton +play baseball +play basketball +play billiard +play football +play pong +play tennis +play volleyball +player +playground +playhouse +playing card +playing chess +playing golf +playing mahjong +playingfield +playpen +playroom +plaza +plier +plot +plow +plug +plug hat +plum +plumber +plumbing fixture +plume +plywood +pocket +pocket watch +pocketknife +pod +podium +poetry +poinsettia +point +pointer +poker card +poker chip +poker table +pole +polecat +police +police car +police dog +police station +politician +polka dot +pollen +pollution +polo +polo neck +polo shirt +pomegranate +pomeranian +poncho +pond +ponytail +poodle +pool +pop +pop artist +popcorn +pope +poppy +porcelain +porch +pork +porridge +portable battery +portal +portfolio +porthole +portrait +portrait session +pose +possum +post +post office +stamp +postcard +poster +poster page +pot +potato +potato chip +potato salad +potholder +potty +pouch +poultry +pound +pour +powder +power line +power plugs and sockets +power see +power station +practice +Prague Castle +prayer +preacher +premiere +prescription +show +presentation +president +press room +pressure cooker +pretzel +prince +princess +print +printed page +printer +printing +prison +produce +product +profession +professional +professor +project picture +projection screen +projector +prom +promenade +propeller +prophet +proposal +protective suit +protest +protester +publication +publicity portrait +ice hockey +pudding +puddle +puff +puffin +pug +pull +pulpit +pulse +pump +pumpkin +pumpkin pie +pumpkin seed +punch bag +punch +student +purple +push +putt +puzzle +tower +pyramid +python +qr code +quail +quarry +quarter +quartz +queen +quesadilla +queue +quiche +quilt +quilting +quote +rabbit +raccoon +race +race track +raceway +race car +racket +radar +radiator +radio +raft +rag doll +rail +railcar +railroad +railroad bridge +railway line +railway station +rain +rain boot +rainbow +rainbow trout +raincoat +rainforest +rainy +raisin +rake +ram +ramp +rapeseed +rapid +rapper +raspberry +rat +ratchet +raven +ravine +ray +razor +razor blade +read +reading +reamer +rear +rear light +rear view +rearview mirror +receipt +receive +reception +recipe +record +record producer +recorder +recording studio +recreation room +recreational vehicle +rectangle +recycling +recycling bin +red +red carpet +red flag +red panda +red wine +redwood +reed +reef +reel +referee +reflect +reflection +reflector +register +rein +reindeer +relax +release +relief +religion +religious +relish +remain +remodel +remote +remove +repair +repair shop +reptile +rescue +rescuer +research +researcher +reservoir +residence +residential neighborhood +resin +resort +resort town +restaurant kitchen +restaurant patio +restroom +retail +retriever +retro +reveal +rhinoceros +rhododendron +rib +ribbon +rice +rice cooker +rice field +ride +ridge +riding +rifle +rim +ring +riot +ripple +rise +rise building +river +river bank +river boat +river valley +riverbed +road +road sign +road trip +roadside +roast chicken +robe +robin +robot +stone +rock arch +rock artist +rock band +rock climber +rock climbing +rock concert +rock face +rock formation +rocker +rocket +rocking chair +rocky +rodent +rodeo +rodeo arena +roe +roe deer +roller +coaster +roller skate +roller skates +rolling pin +romance +romantic +roof +roof garden +room +room divider +root +root beer +rope bridge +rosary +rose +rosemary +rosy cloud +rottweiler +round table +router +row +rowan +royal +rubber stamp +rubble +rubik's cube +ruby +ruffle +rugby +rugby ball +rugby player +ruins +ruler +rum +run +runner +running shoe +rural +rust +rustic +rye +sack +saddle +saddlebag +safari +safe +safety vest +sage +sail +sailboat +sailing +sailor +squirrel monkey +sake +salad +salad bowl +salamander +salami +sale +salmon +salon +salsa +salt +salt and pepper shakers +salt lake +salt marsh +salt shaker +salute +samoyed +samurai +sand +sand bar +sand box +sand castle +sand sculpture +sandal +sandwich +sanitary napkin +santa claus +sapphire +sardine +sari +sashimi +satay +satchel +satellite +satin +sauce +saucer +sauna +sausage +savanna +saw +sawbuck +sax +saxophonist +scaffold +scale +scale model +scallop +scar +strawman +scarf +scene +scenery +schnauzer +school +school bus +school uniform +schoolhouse +schooner +science +science fiction film +science museum +scientist +scissors +wall lamp +scone +scoop +scooter +score +scoreboard +scorpion +scout +scrambled egg +scrap +scraper +scratch +screen +screen door +screenshot +screw +screwdriver +scroll +scrub +scrubbing brush +sculptor +sculpture +sea cave +sea ice +sea lion +sea turtle +sea urchin +seabass +seabed +seabird +seafood +seahorse +seal +sea view +seashell +seaside resort +season +seat +seat belt +seaweed +secretary +security +sedan +see +seed +seesaw +segway +selfie +sell +seminar +sense +sensor +server +server room +service +set +sewing machine +shadow +shake +shaker +shampoo +shape +share +shark +sharpener +sharpie +shaver +shaving cream +shawl +shear +shears +sheep +sheet +sheet music +shelf +shell +shellfish +shelter +shelve +shepherd +sherbert +shiba inu +shine +shipping +shipping container +shipwreck +shipyard +shirt +shirtless +shoal +shoe +shoe box +shoe shop +shoe tree +shoot +shooting basketball guard +shop window +shopfront +shopper +shopping +shopping bag +shopping basket +shopping cart +mall +shopping street +shore +shoreline +short +short hair +shorts +shot glass +shotgun +shoulder +shoulder bag +shovel +showcase +shower +shower cap +shower curtain +shower door +shower head +shredder +shrew +shrimp +shrine +shrub +shutter +siamese +siberia +sibling +side +side cabinet +side dish +sidecar +sideline +siding +sign +signage +signal +signature +silk +silk stocking +silo +silver +silver medal +silverware +sing +singe +singer +sink +sip +sit +sitting +skate park +skateboard +skateboarder +skater +skating rink +skeleton +sketch +skewer +ski +ski boot +ski equipment +ski jacket +ski lift +ski pole +ski resort +snowboard +skier +skiing shoes +skin +skull +skullcap +sky +sky tower +skylight +skyline +skyscraper +slalom +slate +sleigh +sleep +sleeping bag +sleepwear +sleeve +slice +slide +slider +sling +slope +slot +slot machine +sloth +slow cooker +slug +slum +smell +smile +smoke +snack +snail +snake +snapper +snapshot +snorkel +snout +snow +snow leopard +snow mountain +snowball +snowboarder +snowfield +snowflake +snowman +snowmobile +snowplow +snowshoe +snowy +soap +soap bubble +soap dispenser +soccer goalkeeper +socialite +sock +socket +soda +softball +software +solar battery +soldier +solo +solution +sombrero +song +sound +soup +soup bowl +soupspoon +sour cream +souvenir +soybean milk +spa +space +space shuttle +space station +spacecraft +spaghetti +span +wrench +spark +sparkle +sparkler +sparkling wine +sparrow +spatula +speaker +spectator +speech bubble +speed limit +speed limit sign +speedboat +speedometer +sphere +spice +spice rack +spider +spider web +spike +spin +spinach +spire +splash +sponge +spoon +sport association +sport equipment +sport team +sports ball +sports equipment +sports meet +sportswear +dot +spray +spread +spring +spring roll +sprinkle +sprinkler +sprout +spruce +spruce forest +squad +square +squash +squat +squeeze +squid +squirrel +water gun +stab +stable +stack +stadium +staff +stage +stage light +stagecoach +stain +stainless steel +stair +stairs +stairwell +stall +stallion +stand +standing +staple +stapler +star +stare +starfish +starfruit +starling +state park +state school +station +stationary bicycle +stationery +statue +steak +steak knife +steam +steam engine +steam locomotive +steam train +steamed bread +steel +steering wheel +stem +stencil +step stool +stereo +stethoscope +stew +stick +stick insect +sticker +still life +stilt +stingray +stir +stirrer +stirrup +sew +stock +stocking +stomach +stone building +stone carving +stone house +stone mill +stool +stop +stop at +stop light +stop sign +stop watch +traffic light +storage box +storage room +tank +store +storefront +stork +storm +storm cloud +stormy +stove +poker +straddle +strainer +strait +strap +straw +straw hat +strawberry +stream +street art +street artist +street corner +street dog +street food +street light +street market +street photography +street scene +street sign +street vendor +stretch +stretcher +strike +striker +string +string cheese +strip +stripe +stroll +structure +studio +studio shot +stuff +stuffed animal +stuffed toy +stuffing +stump +stunning +stunt +stupa +style +stylus +submarine +submarine sandwich +submarine water +suburb +subway +subway station +subwoofer +succulent +suede +sugar +sugar bowl +sugar cane +sugar cube +suit +suite +summer +summer evening +summit +sun +sun hat +sunbathe +sunday +sundial +sunflower +sunflower field +sunflower seed +sunglasses +sunny +sunrise +sunset +sunshade +sunshine +super bowl +sports car +superhero +supermarket +supermarket shelf +supermodel +supporter +surf +surface +surfboard +surfer +surgeon +surgery +surround +sushi +sushi bar +suspenders +suspension +suspension bridge +suv +swallow +swallowtail butterfly +swamp +swan +swan boat +sweat pant +sweatband +sweater +sweatshirt +sweet +sweet potato +swim +swim cap +swimmer +swimming hole +swimming pool +swing +swing bridge +swinge +swirl +switch +swivel chair +sword +swordfish +symbol +symmetry +synagogue +syringe +syrup +system +t shirt +t-shirt +tabasco sauce +tabby +table tennis racket +table top +tablecloth +tablet computer +tableware +tachometer +tackle +taco +tae kwon do +tai chi +tail +tailor +take +takeoff +talk +tambourine +tan +tangerine +tape +tapestry +tarmac +taro +tarp +tart +tassel +taste +tatami +tattoo +tattoo artist +tavern +tea +tea bag +tea party +tea plantation +tea pot +tea set +teach +teacher +teacup +teal +team photo +team presentation +tear +technician +technology +teddy +tee +teenager +telegraph pole +zoom lens +telescope +television +television camera +television room +television studio +temperature +temple +tempura +tennis +tennis court +tennis match +tennis net +tennis player +tennis racket +tent +tequila +terminal +terrace +terrain +terrarium +territory +test +test match +test tube +text +text message +textile +texture +thanksgiving +thanksgiving dinner +theater +theatre actor +therapy +thermometer +thermos +thermos bottle +thermostat +thicket +thimble +thing +thinking +thistle +throne +throne room +throw +throw pillow +thunder +thunderstorm +thyme +tiara +tick +ticket +ticket booth +tide pool +tie +tiger +tight +tile +tile flooring +tile roof +tile wall +tin +tinfoil +tinsel +tiramisu +tire +tissue +toast +toaster +tobacco +tobacco pipe +toddler +toe +tofu +toilet bowl +toilet seat +toiletry +tokyo tower +tomato +tomato sauce +tomato soup +tomb +tong +tongs +tool +toolbox +toothbrush +toothpaste +toothpick +topiary garden +topping +torch +tornado +tortilla +tortoise +tote bag +totem pole +totoro +toucan +touch +touchdown +tour +tour bus +tour guide +tourist +tourist attraction +tournament +tow truck +towel +towel bar +tower block +tower bridge +town +town square +toy +toy car +toy gun +toyshop +track +tractor +trade +tradition +traditional +traffic +traffic cone +traffic congestion +traffic jam +traffic sign +trail +trailer +trailer truck +train +train bridge +train car +train interior +train track +train window +trainer +training +training bench +training ground +trolley +trampoline +transformer +transparency +travel +tray +treadmill +treat +tree +tree branch +tree farm +tree frog +tree house +tree root +tree trunk +trial +triangle +triathlon +tribe +tributary +trick +tricycle +trim +trio +tripod +trombone +troop +trophy +trophy cup +tropic +trout +truck +truck driver +tub +tube +tugboat +tulip +tuna +tundra +tunnel +turbine +turkey +turn +turnip +turquoise +turret +turtle +tusk +tv actor +tv cabinet +tv drama +tv genre +tv personality +tv show +tv sitcom +tv tower +twig +twilight +twin +twine +twist +type +type on +typewriter +ukulele +ultraman +umbrella +underclothes +underwater +unicorn +uniform +universe +university +up +urban +urinal +urn +use +utensil +utility room +vacuum +valley +valve +vampire +van +vanilla +vanity +variety +vase +vault +vector cartoon illustration +vector icon +vegetable +vegetable garden +vegetable market +vegetation +vehicle +veil +vein +velvet +vending machine +vendor +vent +vespa +vessel +vest +vet +veteran +veterinarians office +viaduct +video +video camera +video game +videotape +view mirror +vigil +villa +village +vine +vinegar +vineyard +violence +violet +violin +violinist +violist +vision +visor +vodka +volcano +volleyball +volleyball court +volleyball player +volunteer +voyage +vulture +waffle +waffle iron +wagon +wagon wheel +waist +waiter +waiting hall +waiting room +walk +walking +walking cane +wall clock +wallpaper +walnut +walrus +war +warehouse +warm +warning sign +warrior +warship +warthog +wash +washer +washing +washing machine +wasp +waste +waste container +watch +water +water bird +water buffalo +water cooler +water drop +water feature +water heater +water level +water lily +water park +water pipe +water purifier +water ski +water sport +water surface +water tower +watercolor +watercolor illustration +watercolor painting +waterfall +watering can +watermark overlay stamp +watermelon +waterproof jacket +waterway +wave +wax +weapon +wear +weather +vane +web +webcam +wedding +wedding ring +wedding bouquet +wedding cake +wedding couple +wedding invitation +wedding party +wedding photo +wedding photographer +wedding photography +wedding reception +wedge +weed +weight +weight scale +welder +well +western food +western restaurant +wet +wet bar +wet suit +wetland +wetsuit +whale +whale shark +wheat +wheat field +wheel +wheelchair +wheelie +whipped cream +whisk +whisker +whiskey +whistle +white +white house +white wine +whiteboard +wicket +wide +wield +wig +Wii +Wii controller +wild +wildebeest +wildfire +wildflower +wildlife +willow +wind +wind chime +wind farm +wind turbine +windmill +window +window box +window display +window frame +window screen +window seat +window sill +wiper +windshield +windy +wine bottle +wine cooler +wine cabinet +wine cellar +wine glass +wine rack +wine tasting +winery +wing +winter +winter melon +winter morning +winter scene +winter sport +winter storm +wire +wisteria +witch +witch hat +wok +wolf +woman +wood +wood duck +wood floor +wood wall +wood-burning stove +wooden spoon +woodland +woodpecker +woodworking plane +wool +job +work card +workbench +worker +workplace +workshop +world +worm +worship +wound +wrap +wrap dress +wrapping paper +wrestle +wrestler +wrinkle +wristband +write +writer +writing +writing brush +writing desk +yacht +yak +yard +yellow +yoga +yoga mat +yoghurt +yoke +yolk +youth +youth hostel +yurt +zebra +zebra crossing +zen garden +zip +zipper +zombie +zongzi +zoo \ No newline at end of file diff --git a/ppcls/utils/ram/ram_tag_list_chinese.txt b/ppcls/utils/ram/ram_tag_list_chinese.txt new file mode 100644 index 0000000000..3f61dc0b84 --- /dev/null +++ b/ppcls/utils/ram/ram_tag_list_chinese.txt @@ -0,0 +1,4585 @@ +三维CG渲染 +3d眼镜 +算盘 +鲍鱼 +修道院 +肚子 +学院 +附件 +事故 +手风琴 +橡子 +丙烯颜料 +表演 +行动 +动作电影 +活动 +演员 +改编本 +添加 +胶带 +调整 +成人 +冒险 +广告 +天线 +有氧运动 +喷雾罐 +爆炸头 +农业 +帮助 +空调 +空调系统 +风向标 +飞机客舱 +飞机模型 +机场 +航线 +客机 +飞行员 +飞机 +飞机窗口 +机场 +机场跑道 +航站楼 +飞艇 +航展 +过道 +警报 +闹钟 +信天翁 +唱片 +唱片封面 +酒精 +壁龛 +水藻 +胡同/球道 +杏仁 +芦荟 +高山 +羊驼 +字母表 +德国牧羊犬 +圣坛 +琥珀 +救护车 +秃鹰 +美国短毛猫 +紫水晶 +圆形剧场 +扩音器 +游乐园 +游乐设施 +锚 +古老的 +海葵 +天使 +角 +动物 +动物雕塑 +动物收容所 +动画片 +动画电影 +动画师 +动漫 +脚踝 +短袜 +周年庆 +风衣 +蚂蚁 +羚羊 +古董 +鹿角 +铁砧 +公寓 +猿 +应用程序 +应用图标 +出现 +外观 +开胃菜 +掌声 +苹果 +苹果汁 +苹果派 +苹果树 +苹果酱 +设备 +约定 +通道 +杏子 +围裙 +浅绿色 +水族馆 +观赏鱼 +渡槽 +游乐中心 +商场游戏机 +拱门 +拱桥 +考古现场 +射箭 +群岛 +建筑师 +建筑设计 +档案 +拱门 +地区 +竞技场 +争论 +手臂 +穿山甲 +臂章 +扶手椅 +衣柜 +盔甲 +军队 +军事基地 +坦克 +阵列 +逮捕 +箭头 +艺术 +艺术展 +美术馆 +艺术印刷品 +艺术学校 +艺术工作室 +艺术矢量插图 +洋蓟 +文章 +手工艺品 +艺术家 +艺术阁楼 +灰 +烟灰缸 +亚洲寺庙 +芦笋 +沥青道路 +组装 +集会 +生产流水线 +协会 +宇航员 +天文学家 +运动员 +运动 +地图集 +自助取款机 +大气层 +中庭 +连接 +战斗机 +参加 +吸引力 +全地形车 +茄子 +拍卖 +奥迪汽车 +音频 +礼堂 +极光 +作者 +汽车厂 +汽车修理工 +汽车零件 +车展 +汽车展厅 +汽车电池 +汽车制造 +汽车模型 +汽车 +秋天 +秋天的森林 +秋天的叶子 +秋天的公园 +秋天的树 +阿凡达 +林荫大道 +飞行员太阳镜 +牛油果 +奖品 +颁奖典礼 +获奖者 +棚 +斧头 +杜鹃花 +狒狒 +婴儿 +奶瓶 +婴儿车 +婴儿衣服 +小象 +婴儿食品 +婴儿座椅 +迎婴派对 +背后/后面 +背景 +背光 +背包 +后院 +培根 +徽章 +獾 +荒地 +羽毛球运动 +羽毛球拍 +袋子 +面包圈 +风笛 +法棍 +诱饵 +焙烤食品 +面包师 +面包店 +烘焙 +烤盘 +平衡 +平衡车 +阳台 +球 +球池 +芭蕾舞女演员 +芭蕾舞 +芭蕾舞演员 +芭蕾舞裙 +气球 +气球拱门 +棒球手 +舞厅 +竹子 +竹林 +香蕉 +香蕉面包 +香蕉叶子 +香蕉树 +乐队 +创可贴 +绷带 +头巾 +束发带 +刘海 +手镯 +栏杆 +五弦琴 +银行 +银行卡 +银行金库 +纸币 +横幅/旗帜 +宴会 +宴会厅 +榕树 +包子 +洗礼 +酒吧 +条形码 +高脚凳 +烧烤 +烧烤架 +杠铃 +理发师 +理发店 +芭比娃娃 +驳船 +咖啡师 +树皮 +大麦 +谷仓 +仓鸮 +挡光板 +桶 +路障 +屏障 +手推车 +酒保 +棒球 +棒球基地 +棒球棒 +棒球帽 +棒球场 +棒球比赛 +棒球手套 +棒球投手 +棒球队 +棒球制服 +地下室 +罗勒 +水盆 +篮子 +篮子 +篮球 +篮球篮板 +篮球教练 +篮球场 +篮球比赛 +篮球框 +篮球运动员 +篮球馆 +篮球队 +贝斯 +低音吉他 +低音喇叭 +贝斯手 +球棒/球拍 +浴室 +水浴加热器 +浴垫 +浴巾 +泳装 +浴袍 +浴室 +浴室配件 +浴室柜 +浴室门 +浴室镜子 +浴室水槽 +卫生纸 +浴室窗户 +蝙蝠侠 +棒子 +接连猛打/击球员 +电池 +战斗 +战绳 +战舰 +海湾 +海湾大桥 +凸窗 +杨梅 +集市 +海滩 +沙滩球 +沙滩椅 +海滨别墅 +海滩小屋 +沙滩毛巾 +沙滩排球 +灯塔 +珠子 +比格犬 +鸟嘴 +烧杯 +横梁 +豆子 +豆袋椅 +豆袋 +熊 +幼熊 +胡子 +野兽 +击打/击败 +美丽的 +美丽 +美容院 +海狸 +床 +床单 +床架 +卧室 +床上用品 +便盆 +卧室窗户 +床头灯 +蜜蜂 +山毛榉 +牛肉 +养蜂人 +蜂鸣器 +啤酒 +啤酒瓶 +啤酒罐 +啤酒花园 +啤酒杯 +啤酒馆 +甜菜 +甲虫 +米色 +时钟 +甜椒 +钟楼 +皮带 +皮带扣 +长凳 +弯曲 +孟加拉虎 +盒饭 +贝雷帽 +浆果 +停泊位 +饮料 +围嘴 +拌饭 +圣经 +比熊 +自行车 +自行车头盔 +自行车车轮 +自行车骑士 +坐浴盆 +大本钟 +自行车道 +自行车道 +自行车赛 +骑车 +比基尼 +比基尼上衣 +账单 +台球 +广告牌 +台球台 +垃圾箱 +活页夹 +双筒望远镜 +生物学实验室 +双翼飞机 +桦木 +桦树 +鸟 +鸟池 +喂鸟器 +鸟舍 +鸟巢 +鸟池 +鸟笼 +出生 +生日 +生日蛋糕 +生日蜡烛 +生日贺卡 +生日聚会 +饼干 +主教 +野牛 +钻头 +咬 +黑色 +黑山羊 +黑莓 +乌鸦 +黑板 +铁匠 +叶片/刀片 +毯子/覆盖层 +运动外套 +看台 +搅拌机 +祝福 +窗帘 +眼罩 +闪光 +暴风雪 +块 +博客 +血 +开花 +花 +女装衬衫 +吹 +吹风机 +河豚 +蓝色 +蓝色艺术家 +蓝松鸦 +蓝天 +蓝莓 +蓝知更鸟 +猪 +板子 +板擦 +棋盘游戏 +木板路 +船 +船甲板 +船屋 +桨 +乘船 +浮标 +山猫 +躯干 +身体冲浪板 +健美运动员 +水煮鸡蛋 +锅炉 +饰扣式领带 +门闩 +炸弹 +轰炸机 +披肩榛鸡 +骨骼 +篝火 +阀盖 +盆景 +书 +书籍封面 +书柜 +文件夹 +书签 +书架 +书店 +远程拾音器 +推动 +靴子 +边界 +边境牧羊犬 +植物园 +瓶 +瓶盖 +开瓶器 +螺旋开瓶器 +三角梅 +巨石 +花束 +时装店 +精品酒店 +鞠躬/蝴蝶结 +领结 +弓形窗 +碗 +保龄球运动 +保龄球馆 +保龄球 +保龄球设备 +盒子 +箱形梁桥 +箱龟 +拳击手 +内裤 +拳击 +拳击手套 +拳击台 +男孩 +支撑物 +支架 +辫子 +大脑 +刹车 +刹车灯 +树枝 +商标 +白兰地 +黄铜 +黄铜牌匾 +面包 +面包箱 +休息 +早餐 +防浪堤 +胸部 +啤酒厂 +砖块 +砖建筑物 +墙 +砖块 +婚纱 +新娘 +新郎 +伴娘 +桥 +缰绳 +公文包 +明亮的 +边沿 +钻头 +广播 +西兰花 +青铜 +铜牌 +青铜雕塑 +青铜雕像 +胸针 +小溪 +扫帚 +肉汤 +棕色 +棕熊 +巧克力蛋糕 +早午餐 +浅黑肤色的女人 +刷子 +郊狼 +包菜 +气泡 +泡泡糖 +珍珠奶茶 +斗柜 +盾牌 +芽 +佛 +水牛 +自助餐 +昆虫 +建造 +建造者 +建筑 +积木 +建筑立面 +建筑材料 +灯 +牛 +斗牛犬 +子弹 +动车 +公告栏 +防弹背心 +斗牛 +扩音器 +斗牛场 +大黄蜂 +保险杠 +卷/地形起伏 +捆 +蹦极 +双层床 +地堡/击球 +兔子 +浮标 +书桌 +墓室 +燃烧 +玉米煎饼 +公交车 +公交车司机 +公交车内部 +公交车站 +公交车站 +公交车窗户 +灌木 +商业 +名片 +业务主管 +商务西装 +业务团队 +女商人 +商人 +半身像 +屠夫 +肉铺 +孤峰 +黄油 +奶油 +蝴蝶 +蝴蝶馆 +按钮 +梧桐树 +购买 +出租车 +小屋 +卷心菜 +小屋/机舱 +守车 +储藏柜 +橱柜 +电缆 +缆车 +仙人掌 +咖啡馆 +食堂 +笼子 +蛋糕 +蛋糕台 +计算器 +大锅 +日历 +小腿 +通话 +电话亭 +书法 +平静的 +摄像机 +骆驼 +相机 +相机镜头 +迷彩 +露营 +露营者 +篝火 +露营 +营地 +校园 +罐 +开罐器 +运河 +金丝雀 +癌症 +蜡烛 +烛台 +糖果 +块状糖 +柺杖糖 +糖果店 +拐杖 +罐子 +大炮 +树冠/顶棚 +四柱床 +香瓜 +悬臂桥 +帆布 +峡谷 +帽子 +斗篷 +科德角 +卡布奇诺 +胶囊 +队长 +捕获 +车 +汽车经销商 +车门 +汽车内饰 +车标 +后视镜 +停车场 +汽车座椅 +车展 +洗车 +车窗 +焦糖 +卡片 +纸牌游戏 +纸板 +纸板盒 +羊毛衫 +红衣凤头鸟 +货物 +货运飞机 +货船 +加勒比 +康乃馨 +狂欢节 +食肉动物 +旋转木马 +鲤鱼 +木匠 +地毯 +拖鞋 +红雀 +长途客车 +斑点狗 +航空母舰 +胡萝卜 +胡萝卜蛋糕 +携带 +手推车 +纸箱/纸盒 +卡通 +卡通人物 +卡通插图 +卡通风格 +雕刻 +容器 +现金 +腰果 +赌场 +砂锅 +磁带 +盒式录音机 +石膏绷带 +铸造 +城堡 +猫 +猫窝 +猫粮 +猫器具 +猫架 +地下墓穴 +双体船 +美洲狮 +握着/抓着 +捕手 +毛毛虫 +鲶鱼 +教堂 +牛 +猫步 +走秀 +菜花 +洞穴 +鱼子酱 +光盘 +CD播放器 +雪松 +天花板 +吊扇 +庆祝 +庆典 +名人 +芹菜 +大提琴 +手机 +水泥 +墓地 +中心装饰品 +蜈蚣 +陶瓷 +瓷砖 +麦片 +仪式 +证书 +链条 +链锯 +椅子 +升降椅 +躺椅 +木屋 +圣杯 +粉笔 +房间 +变色龙 +香槟酒 +香槟杯 +冠军 +锦标赛 +吊灯 +婴儿换尿布台 +通道 +皴裂处 +小教堂 +人物雕塑 +木炭 +充电 +充电器 +战车 +慈善机构 +慈善活动 +魅力 +图表 +追逐 +底盘 +检查/支票 +支票簿 +棋盘 +检查表 +欢呼声 +鼓励/啦啦队 +奶酪 +奶酪汉堡 +奶酪蛋糕 +猎豹 +厨师 +化合物 +化学家 +化学 +化学实验室 +旗袍 +樱桃 +樱花 +樱桃番茄 +樱桃树 +国际象棋 +栗子 +鸡 +鸡胸肉 +鸡笼 +鸡肉沙拉 +鸡翅 +鹰嘴豆 +小衣橱 +吉娃娃 +孩子 +童星 +孩子的房间 +红番椒 +辣热狗 +烟囱 +黑猩猩 +瓷器 +白菜 +中国园林 +中国结 +月季 +中国塔 +炸薯条/炸薯条 +花栗鼠 +凿子 +巧克力 +巧克力棒 +巧克力蛋糕 +巧克力碎片 +巧克力饼干 +巧克力牛奶 +巧克力慕斯 +松露 +唱诗班 +厨房刀 +砧板 +筷子 +圣诞节 +圣诞球 +圣诞贺卡 +圣诞装饰 +圣诞晚宴 +平安夜 +圣诞帽 +圣诞灯 +圣诞市场 +圣诞装饰 +圣诞树 +菊花 +教堂 +教堂塔 +苹果酒 +雪茄 +雪茄盒 +香烟 +烟盒 +腰带 +电影院 +摄影师 +肉桂 +圆 +电路 +电路板 +马戏团 +水箱 +柑橘类水果 +城市 +城市公交 +市政厅 +城市夜景 +城市公园 +城市天际线 +城市广场 +城市街道 +城墙 +城市景观 +蛤蜊 +单簧管 +扣子 +班级 +经典 +教室 +锁骨 +爪子 +黏土 +陶器 +清洁 +洁净室 +清洁工人 +清洁用品 +清晰的 +栓 +克莱门氏小柑橘 +客户端 +悬崖 +爬 +爬山 +登山者 +诊所 +夹子 +剪贴画 +剪贴板 +快速帆船 +君子兰 +斗篷 +木底鞋 +特写 +壁橱 +布 +穿衣 +衣服 +晒衣夹 +晒衣绳 +服装店 +云 +云雾森林 +多云 +三叶草 +小丑 +小丑鱼 +俱乐部 +离合器 +手拿包 +煤炭 +海岸 +外套 +衣帽架 +玉米 +公鸡 +凤头鹦鹉 +可卡犬 +驾驶 +蟑螂 +鸡尾酒 +小礼服 +鸡尾酒调制器 +鸡尾酒桌 +可可 +椰子 +椰子树 +咖啡 +咖啡豆 +咖啡杯 +咖啡机 +咖啡店 +咖啡壶 +棺材 +法国白兰地 +螺旋 +硬币 +可口可乐 +滤器 +冷的 +卷心菜沙拉 +合作 +拼贴画 +收藏品 +大学生 +牧羊犬 +碰撞 +颜色 +涂色书 +染色材料 +矮种马 +柱子 +梳子 +密码锁 +喜剧演员 +喜剧 +喜剧电影 +彗星 +舒服 +安慰食物 +漫画书 +漫画人物 +连环画 +指挥官 +评论员 +社区 +通勤 +公司 +指南针 +比赛 +比赛 +竞争者 +作曲家 +作文 +堆肥 +电脑 +电脑机箱 +电脑椅 +电脑桌 +键盘 +计算机显示器 +计算机房 +电脑屏幕 +机箱 +概念车 +音乐会 +音乐厅 +贝壳 +混凝土 +调味品 +避孕套 +独立产权的公寓 +指挥 +锥形物 +会议 +会议中心 +会议厅 +会议室 +五彩纸屑 +冲突 +合流 +连接 +连接器 +温室 +星座 +建筑工地 +建筑工人 +包含 +容器 +集装箱船 +大陆 +轮廓 +合同 +控制 +控制塔 +便利店 +集会 +交谈 +转换器 +可转换的 +输送机 +厨师/烹饪 +烹饪 +烹饪喷雾剂 +炊具 +凉的 +冷却器 +铜 +一本/一册 +珊瑚 +珊瑚礁 +粗绳 +有线电话 +酒 +威尔士矮脚狗 +瓶塞 +软木板 +鸬鹚 +玉米 +玉米田 +玉米面包 +角落 +小号 +飞檐 +燕麦片 +围栏 +走廊 +紧身衣 +化妆品 +化妆刷 +化妆镜 +角色扮演 +服装 +服装电影设计师 +婴儿床 +小屋 +棉花 +棉花糖 +沙发 +倒计时 +柜台 +台面 +最佳乡村歌手 +乡村别墅 +乡村公路 +乡村流行歌手 +农村 +双门小轿车 +夫妇/两人/几个 +情侣写真 +小胡瓜 +课程 +球场 +法院 +院子 +堂兄弟 +工作服 +奶牛 +母牛的颈铃 +牛仔 +牛仔靴 +牛仔帽 +螃蟹 +蟹肉 +裂纹 +摇篮 +工艺 +工匠 +蔓越莓 +起重机 +黑纱 +厕所 +板条箱 +火山口湖 +龙虾 +蜡笔 +奶油乳酪 +奶油罐 +创建 +生物 +信用卡 +新月形 +新月形面包 +山顶 +全体船员 +蟋蟀 +板球用球 +板球队 +板球队员 +钩边 +克罗克电锅 +鳄鱼 +庄稼 +露脐上衣 +交叉 +横木 +十字路口 +相声 +人行横道 +油煎面包块 +乌鸦 +撬棍 +人群 +拥挤的 +皇冠 +阴极射线管屏幕 +耶稣受难像 +巡游 +游轮 +巡洋艇 +面包屑 +压坏 +拐杖 +水晶 +幼兽 +立方体 +黄瓜 +球杆 +袖口 +袖扣 +烹饪 +农田 +杯子 +纸杯蛋糕 +丘比特 +马路牙子 +旋度 +卷发器 +无籽葡萄干 +货币 +咖喱 +窗帘 +曲线 +软垫 +顾客 +切 +餐具 +自行车 +骑自行车 +龙卷风 +汽缸 +铙钹 +柏树 +柏树 +达克斯猎狗 +水仙花 +匕首 +大丽花 +萝卜 +乳制品 +雏菊 +大坝 +损害 +潮湿的 +跳舞 +舞池 +舞蹈室 +舞者 +蒲公英 +黑暗 +黑暗 +飞镖 +圆靶 +指示板 +日期 +女儿 +黎明 +天床上 +日光 +门栓 +死亡 +辩论 +碎片 +玻璃水瓶 +甲板 +双层巴士 +装饰 +装修/装饰 +装饰画 +鹿 +后卫 +神 +熟食 +投递 +拆迁 +怪兽 +演示 +兽窝/休闲室 +牛仔夹克 +牙医 +百货商店 +抑郁症 +德比 +皮肤病 +沙漠 +沙漠公路 +设计 +设计师 +桌子/表格 +台灯 +桌面 +台式电脑 +甜点 +破坏 +侦探 +洗涤剂 +露水 +仪表盘 +钻石 +尿布 +尿布包 +杂志 +死 +饮食 +挖掘机 +数字 +数字时钟 +莳萝 +晚餐 +小船 +餐厅 +晚宴 +餐桌 +恐龙 +浸 +文凭 +指引 +导演 +尘埃 +越野摩托车 +泥土地 +泥土路 +泥路/土路 +灾难 +信徒 +迪斯科舞厅 +迪斯科灯秋 +迪斯科舞厅 +疾病 +盘子 +碟形天线 +洗碗机 +抹布 +菜肴 +洗碗液 +迪斯尼乐园 +自动售货机 +展示 +陈列窗 +壕沟 +潜水 +潜水员 +跳水板 +纸杯 +流行音乐播音员 +杜宾犬 +码头 +医生 +文件 +纪录片 +狗 +狗窝 +犬种 +狗项圈 +狗粮 +狗窝 +洋娃娃 +美元 +玩偶之家 +洋娃娃 +海豚 +穹顶 +住宅 +多米诺骨牌 +驴 +甜甜圈 +涂鸦 +门 +门把手 +受气包 +门牌 +门口 +宿舍 +面团 +市中心 +推土机 +拖 +龙 +蜻蜓 +排水沟 +剧本 +戏剧电影 +画 +抽屉里 +图画/画画 +图钉 +辫子 +连衣裙/特定场合的服装 +礼帽 +正装衬衫 +皮鞋 +大礼服 +梳妆台 +更衣室 +运球 +漂移 +浮木 +钻 +饮品/喝 +饮用水 +开车 +司机 +车道 +无人机 +水滴/下降 +吊灯 +滴管 +干旱 +药物 +药店 +鼓 +鼓手 +鸡腿 +干的 +公爵夫人 +鸭子 +鸭嘴兽 +小鸭子 +布基胶带 +伙计 +二重唱 +粗呢 +独木舟 +哑铃 +饺子 +沙丘 +扣篮 +榴莲 +黄昏 +灰尘 +垃圾车 +簸箕 +羽绒被 +DVD +染料 +鹰 +耳朵 +御寒耳罩 +耳机 +耳塞 +耳环 +地震 +画架 +复活节 +复活节兔子 +复活节彩蛋 +吃 +餐厅 +泡芙 +日食 +生态系统 +编辑 +教育 +教育家 +鳗鱼 +蛋 +蛋卷 +蛋挞 +打蛋器 +白鹭 +埃菲尔铁塔 +橡皮筋 +上级 +电椅 +电钻 +电工 +电 +电子 +电子器件 +大象 +高度图 +电梯 +电梯轿厢 +电梯门 +电梯大堂 +电梯井 +路堤 +大使馆 +装饰 +灰烬 +会徽 +刺绣 +翡翠 +紧急 +紧急服务 +紧急车辆 +情感 +帝国大厦 +搪瓷 +外壳/围墙 +茶几 +能源 +订婚 +订婚戒指 +引擎 +机舱 +工程师 +工程 +英国短毛猫 +乐团 +回车键 +演艺人员 +娱乐 +娱乐中心 +入口 +入口大厅 +信封 +马术 +设备 +橡皮擦 +二胡 +侵蚀 +自动扶梯 +食用蜗牛 +浓缩咖啡 +房地产 +河口 +桉树 +晚上 +晚礼服 +夜光 +傍晚天空 +晚上的太阳 +事件 +常绿的 +母羊 +挖掘 +运动 +排气罩 +展览 +出口 +探险者 +爆炸 +延长线 +灭火器 +排气扇 +挤压 +眼睛 +眼影 +眉 +眼线笔 +布料 +纺织品商店 +外观 +脸 +脸部特写 +蜜粉 +毛巾 +面巾纸架 +设施 +工厂 +工厂车间 +集市 +露天市场 +仙女 +猎鹰 +秋天 +家庭 +家庭轿车 +全家福 +家庭房 +风扇/扇子 +尖牙 +农场 +农民 +农民市场 +农舍 +时尚 +时尚配饰 +时装设计师 +时尚的女孩 +时装插图 +时装大片 +时装模特 +时装表演 +快餐 +西式快餐 +父亲 +水龙头 +故障 +动物 +小鹿 +传真 +宴会 +羽毛 +软呢帽 +饲料 +一餐 +饲养 +喂养的椅子 +猫科 +美洲狮 +栅栏 +芬达 +蕨类植物 +雪貂 +摩天轮 +渡船 +肥料 +节日 +纤维 +小说 +小说书 +田野/场地/野外 +田间道路 +无花果 +打架 +花样滑冰运动员 +小雕像 +文件 +档案照片 +文件柜 +填满 +胶片相机 +电影导演 +电影格式 +电影首映礼 +电影制片人 +拍摄 +过滤器 +鳍 +手 +终点线 +冷杉 +冷杉树 +火 +火灾报警 +消防部门 +消防车 +消防通道 +消防水带 +火坑 +消防站 +爆竹 +消防队员 +壁炉 +烟花 +烟花表演 +急救箱 +鱼 +鱼船 +海鲜市场 +鱼塘 +鱼缸 +渔夫 +钓鱼 +渔船 +渔网 +钓鱼 +渔村 +健身 +健身课程 +五个 +固定装置 +峡湾 +国旗 +旗杆 +小薄片 +火焰 +火烈鸟 +法兰绒 +拍打 +耀斑 +闪光 +烧瓶 +平 +比目鱼 +风味 +跳蚤 +跳蚤市场 +舰队 +飞行 +空中乘务员 +翻转 +触发器 +翻转图 +浮动 +群 +洪水 +地板/地面 +落地扇 +脚垫 +楼层平面图 +落地窗 +插花艺术 +花店 +牙线 +面粉 +流动 +花 +花篮 +花坛 +花箱 +花田 +花童 +花卉市场 +流体 +冲洗 +长笛 +飞 +飞行钓鱼 +传单 +马 +泡沫 +雾 +多雾的 +鹅肝酱 +箔纸 +折椅 +树叶 +民间艺术家 +民间舞蹈 +民间摇滚艺术家 +方旦糖 +火锅 +圣洗池 +食物 +食用色素 +美食广场 +食品加工机 +小吃摊 +快餐车 +桌上足球 +脚 +人行桥 +足球 +足球教练 +大学橄榄球赛 +足球比赛 +足球场 +足球比赛 +橄榄球头盔 +足球运动员 +足球场 +足球队 +小路 +脚印 +脚踏板 +台座 +鞋子 +故宫 +浅滩 +额头 +森林 +森林大火 +森林地面 +森林小路 +森林公路 +锻造 +餐叉 +叉车 +表格 +园林 +队列/形成物 +F1方程式赛车 +堡垒 +碉堡 +追逐 +化石 +粉底 +喷泉 +钢笔 +狐狸 +框架 +雀斑 +高速公路 +卡车 +法国 +法国斗牛犬 +薯条 +法式吐司 +化妆水 +冰箱 +炸鸡 +煎蛋 +炒饭 +友谊 +飞盘 +青蛙 +霜 +结霜 +严寒 +结冰 +水果 +水果蛋糕 +水果盘 +水果市场 +水果沙拉 +水果摊 +果树 +水果商店 +油炸食品 +煎锅 +软糖 +燃料 +吸烟罩 +有趣的 +葬礼 +真菌 +漏斗 +毛皮衣服 +毛皮大衣 +家具 +蒲团 +小工具 +枪口 +星云/星系 +美术馆 +游戏 +游戏棋盘 +游戏手柄 +火腿 +团伙 +车库 +车库门 +手工模型 +垃圾 +花园 +花园芦笋 +橡胶软管 +花园蜘蛛 +园丁 +园艺 +加菲猫 +滴水嘴 +花环 +大蒜 +衣服 +气体 +加油站 +煤气炉 +防毒面具 +收集 +聚集 +测量仪器 +露台 +齿轮 +壁虎 +艺妓 +凝胶 +百货商店 +发电机 +天竺葵 +幽灵 +礼物 +礼品袋 +礼品篮 +礼物盒 +礼品卡 +礼品商店 +礼物包装 +演唱会 +杜松子酒 +姜 +姜饼 +姜饼屋 +银杏树 +长颈鹿 +女孩 +给 +冰川 +角斗士 +玻璃珠 +玻璃瓶 +玻璃碗 +玻璃箱 +玻璃建筑 +玻璃门 +玻璃地板 +玻璃屋 +玻璃罐 +玻璃板 +玻璃桌子 +玻璃花瓶 +玻璃墙 +玻璃窗 +眼镜 +光滑面 +滑翔机 +地球 +手套 +发光 +汤圆 +去 +袭击 +球门 +守门员 +山羊 +羊奶酪 +戈壁 +护目镜/墨镜 +黄金 +金牌 +金门大桥 +金毛猎犬 +金鱼 +高尔夫运动 +高尔夫球帽 +高尔夫球车 +高尔夫球杆 +高尔夫球场 +高尔夫球手 +鹅 +大猩猩 +哥特式 +葫芦 +政府 +政府机构 +礼服 +毕业生 +毕业典礼 +谷物 +逆戟鲸 +大奖赛 +祖父 +祖母 +祖父母 +花岗岩 +格兰诺拉麦片 +葡萄 +西柚 +葡萄酒 +草 +蚱蜢 +草原 +长满草的 +擦菜器 +坟墓 +碎石 +墓碑 +肉汁 +调味汁瓶 +灰色 +吃草 +放牧 +绿色 +绿色植物 +欢迎 +问候 +贺卡 +灰狗 +网格 +筛子 +烧烤架 +格栅 +烤鳗鱼 +磨 +研磨机 +粗燕麦粉 +杂货袋 +洞穴 +地松鼠 +群体 +合影 +小树林 +生长 +牛油果酱 +警卫 +看门狗 +宾馆 +客房 +指南 +豚鼠 +吉他 +吉他手 +海湾 +海鸥 +枪 +高达 +谒师所 +古筝 +健身房 +体操运动员 +栖息地 +黑客 +冰雹 +头发 +头发颜色 +发胶 +毛刷 +发型 +发夹 +发网 +发夹 +发型 +一半 +礼堂 +万圣节 +万圣节服装 +万圣节南瓜 +露背装 +汉堡 +汉堡包 +哈密瓜 +锤子 +吊床 +阻碍 +仓鼠 +烘手机 +放大镜 +擦手巾 +手提包 +手球 +手铐 +手枪 +手帕 +把手 +手锯 +握手 +倒立 +手写 +汉服 +悬挂 +飞机库 +衣架 +幸福 +海港 +斑海豹 +硬摇滚艺术家 +精装书 +建筑工人 +硬件 +五金店 +硬木 +硬木地板 +口琴 +管风琴 +羽管键琴 +收获 +收割机 +坐垫/搁脚凳/草丛 +帽子 +帽盒 +双簧管 +山楂 +干草 +干草地 +榛子 +头 +主教练 +大灯 +床头板 +头饰 +海岬 +总部 +听力 +心脏 +心形 +热能 +加热器 +帚石楠 +树篱 +刺猬 +脚后跟 +直升机 +直升机机场 +头盔 +帮助 +母鸡 +指甲花 +药草 +兽群 +寄居蟹 +英雄 +苍鹭 +芙蓉花 +芙蓉花 +隐藏/隐蔽处 +高杠 +高跟鞋 +高地 +突出 +徒步旅行 +徒步旅行者 +徒步靴 +登山设备 +山丘 +丘陵地 +别墅 +山坡 +印度教寺庙 +铰链 +臀部 +嘻哈艺人 +河马 +历史学家 +历史遗迹 +历史 +曲棍球 +冰球馆 +曲棍球比赛 +曲棍球运动员 +曲棍球棒 +锄头 +洞 +假日 +冬青树 +海参 +家/住宅 +家用电器 +基地 +家居装饰 +室内设计 +内政部 +家庭影院 +家庭作业 +鹰嘴豆泥 +蜂蜜 +蜂窝 +蜜月 +风帽 +连帽衫 +挂钩/勾住 +跳 +地平线 +犀鸟 +长角牛 +大黄蜂 +震惊 +恐怖电影 +马鞍褥 +马车 +马场 +骑马 +马背 +马蹄铁 +软管 +医院 +医院病床 +病房 +主持人 +小旅馆 +热 +热气球 +热狗 +辣椒酱 +温泉 +旅馆 +酒店大堂 +酒店房间 +电炉 +沙漏 +房子 +房子外部 +室内植物 +悬滑板 +吼 +蜷缩 +拥抱 +呼啦圈 +人 +增湿器 +蜂鸟 +座头鲸 +打猎 +狩猎小屋 +障碍 +飓风 +哈士奇 +小屋 +鬣狗 +混合物 +绣球花 +消火栓 +水上飞机 +冰 +冰袋 +北极熊 +冰洞 +冰淇淋 +冰淇淋蛋卷 +冰淇淋商店 +冰块 +浮冰 +冰球运动员 +冰球队 +棒棒糖 +制冰机 +溜冰场 +冰雕 +冰架 +溜冰鞋 +滑冰 +冰山 +冰柱 +糖衣/酥皮 +图标 +身份证照片 +身份证 +冰屋 +光/灯光/光线 +鬣蜥蜴 +照亮 +插图 +形象 +黑斑羚 +熏香 +独立日 +个人 +室内 +划船器 +电磁炉 +工业区 +工业 +步兵 +充气艇 +服务台 +基础设施 +成分 +吸入器 +注射 +受伤 +墨水 +印泥 +小湖湾 +题词 +昆虫 +安装 +乐器/器械 +绝缘杯 +互动 +室内设计 +网站 +十字路口 +面试 +无脊椎动物 +邀请 +平板电脑 +苹果手机 +苹果音乐播放器 +虹膜 +铁 +熨衣板 +灌溉系统 +岛 +小岛 +等足类动物 +象牙 +常青藤 +居酒屋 +千斤顶 +帝王蟹/蟹 +夹克衫 +按摩浴缸 +玉 +美洲虎 +监狱牢房 +果酱 +日式花园 +茉莉花 +下巴 +松鸦 +爵士乐 +爵士乐艺术家 +爵士融合艺术家 +牛仔裤 +吉普车 +果冻 +果冻豆 +水母 +喷气式飞机 +摩托艇 +珠宝 +珠宝 +珠宝店 +拼图游戏 +人力车 +赛马骑师 +赛马帽 +慢跑 +联合的 +记者 +操纵杆 +法官 +水壶 +玩杂耍 +果汁 +榨汁器 +枣子 +跳绳 +连身裤 +丛林 +废品堆放场 +羽衣甘蓝 +万花筒 +袋鼠 +卡拉ok +空手道 +卡丁车运动 +旧城区 +皮船 +烤肉串 +按键/钥匙 +门卡 +卡其色 +踢 +苏格兰裙 +和服 +幼儿园教室 +幼儿园 +国王 +帝王蟹 +亲吻 +工具包 +厨房 +厨房橱柜 +厨房台面 +厨房地板 +厨房抽油烟机 +厨房岛 +厨房水槽 +厨房桌子 +厨房用具 +厨房窗户 +厨房用具 +风筝 +猕猴桃 +护膝 +跪下 +餐刀 +骑手 +编织 +编织针 +球形把手 +门环 +结 +考拉 +锦鲤 +ktv +实验室 +实验室外套 +标签 +拉布拉多 +迷宫 +网眼织物 +蕾丝连衣裙 +梯子 +长柄杓 +瓢虫 +环礁湖 +湖泊 +湖区 +湖边小屋 +湖岸 +羊肉 +羊排 +灯柱 +灯罩 +矛 +土地 +陆地车辆 +废物填埋 +着陆 +降落甲板 +地标 +风景 +山崩 +挂带 +灯笼 +腿/大腿 +笔记本电脑 +笔记本键盘 +幼体 +烤宽面条 +激光 +睫毛 +套索 +门闩 +乳胶 +拿铁咖啡 +笑 +发射 +发布会 +举办会议 +自助洗衣店 +洗衣房 +洗衣篮 +洗衣房 +熔岩 +薰衣草 +草坪 +草坪婚礼 +律师 +躺 +引领 +主唱 +通向 +领袖 +泄漏 +倾斜/倚靠 +学习 +皮带 +皮革 +皮夹克 +皮鞋 +演讲 +演讲厅 +教学室 +窗台 +剩饭 +腿 +传说 +紧身裤/秋裤 +立法院 +乐高 +豆类 +柠檬 +柠檬汁 +柠檬水 +狐猴 +镜头 +眩光 +扁豆 +豹 +紧身连衣裤 +紧身裤袜 +小妖精 +课程 +信函 +信箱 +信的标志 +刻字 +生菜 +水平 +图书馆 +许可证 +车牌 +地衣 +舔 +盖子 +躺着 +安全带 +救生衣 +救生艇 +救生员 +提起 +灯具 +灯光秀 +电灯开关 +照明/照明设备 +闪电 +避雷针 +淡紫色 +百合 +肢体 +石灰 +石灰石 +豪华轿车 +线条 +艺术线条 +排队 +亚麻 +邮轮 +狮子 +润唇膏 +口红 +液体 +酒类商店 +列表 +荔枝 +生活 +家畜 +客厅 +生活空间 +蜥蜴 +负载 +装卸码头 +游手好闲的人 +走廊 +定位 +锁 +闸室 +储物柜 +阁楼 +原木 +小木屋 +标志 +洛基 +长头发 +冲浪板 +隐约显现/织布机 +环状 +遗失 +彩票 +莲花 +爱 +双人沙发 +行李 +木材 +伐木工人 +午餐 +午餐盒 +郁郁葱葱的 +奢侈品 +豪华游艇 +雨衣 +澳洲胡桃 +短尾猿 +通心粉 +金刚鹦鹉 +弯刀 +机器 +机枪 +杂志 +魔法 +魔术师 +磁铁 +放大镜 +木兰花 +喜鹊 +麻将 +象夫 +女仆 +邮件 +邮件槽 +制作 +改造 +化妆师 +化妆工具 +野鸭 +野鸭 +槌棒 +哺乳动物 +猛犸象 +男人 +管理 +经理 +海牛 +曼荼罗 +橘子 +普通话 +鬃毛 +漫画 +食槽 +芒果 +山竹果 +红树林 +曼哈顿 +检修孔 +井盖 +修指甲 +人体模型 +庄园主宅 +大厦 +螳螂 +地幔 +活动房层 +制造业 +手稿 +地图 +枫木 +枫叶 +枫糖浆 +沙球 +马拉松 +大理石 +行进 +行进乐队 +母马 +金盏花 +水兵 +海洋无脊椎动物 +海洋哺乳动物 +木偶 +标志 +集市 +市场广场 +市场摊位 +结婚 +武术 +武术家 +武术馆 +马提尼 +马丁尼酒杯 +睫毛膏 +吉祥物 +土豆泥 +搅碎机 +面具/口罩 +按摩 +桅杆 +地垫 +斗牛士 +比赛 +火柴盒 +衣料 +床垫 +陵墓 +长裙 +一餐 +量杯 +卷尺 +肉类 +肉丸 +机械师 +机械风扇 +奖牌 +媒体 +医疗设备 +医学图像 +医务人员 +医药箱 +中世纪的 +麦地那市 +冥想 +猫鼬 +赛事 +香瓜 +纪念碑 +菜单 +美人鱼 +网 +肮脏 +信使袋 +金属 +金属艺术家 +金属探测器 +计量器 +中层楼 +麦克风 +显微镜 +微波炉 +午夜 +里程碑 +军装 +牛奶 +牛奶罐 +奶茶 +奶昔 +磨坊 +矿井 +矿工 +矿物质 +矿泉水 +迷你 +微缩模型 +面包车 +部长 +小型货车 +薄荷 +薄荷糖 +镜子 +小姐 +投掷物 +任务 +槲寄生 +混合 +搅拌机 +搅拌碗 +混合物 +护城河 +电动踏板车 +模型/模特 +汽车模型 +现代 +现代大厦 +潮湿 +模具 +模具 +鼹鼠 +君主 +钱 +监控器 +和尚 +猴子 +活动扳手 +黑白照片 +独轮脚踏车 +怪物卡车 +月亮 +月饼 +月光 +沼泽 +驼鹿 +拖把 +助力车 +早晨 +晨雾 +晨光 +朝阳 +砂浆 +马赛克 +清真寺 +蚊子 +藓类植物 +汽车旅馆 +蛾 +母亲 +主板 +主题 +动作 +电动机 +摩托车 +摩托车 +摩托车头盔 +摩托车赛车手 +骑摩托车的人 +赛车运动 +土堆 +山 +山地自行车 +山地自行车员 +山地自行车运动 +山地大猩猩 +山湖 +山景观 +山口 +山路 +山脉 +山区河流 +山雪 +山间溪流 +山景城 +山村 +登山者 +登山包 +鼠标/鼠 +鼠标垫 +捕鼠器 +嘴 +漱口水 +移动 +电影海报 +电影票 +割草机 +mp3播放器 +先生 +泥 +松饼 +马克杯 +桑树 +覆盖物 +骡子 +直辖市 +壁画 +肌肉 +肌肉车 +博物馆 +蘑菇 +音乐 +音乐节 +音乐凳子 +音乐工作室 +音乐录影带表演者 +音乐键盘 +音乐家 +贻贝 +芥末 +神话 +烤干酪辣味玉米片 +指甲油 +指甲锉 +保姆 +餐巾 +狭窄的 +国旗 +基督诞生的场景 +自然历史博物馆 +自然 +自然保护区 +导航 +九夜节 +海军 +星云 +脖子 +围颈带/领口 +项链 +领口 +花蜜 +油桃 +针状物 +邻居 +与某处邻近的地区 +霓虹灯 +霓虹灯 +神经 +巢 +新年 +新生的 +纽芬兰 +新婚 +新闻 +记者招待会 +报摊 +晚上 +夜市 +夜空 +夜景 +夜总会 +床头柜 +面条 +鼻子 +鼻羁 +注解 +笔记本 +记事本 +信纸 +公告 +数字图标 +修女 +护士 +托儿所 +养老院 +螺母 +胡桃夹子 +橡木 +橡树 +桨 +绿洲 +烘干室 +燕麦片 +燕麦 +方尖塔 +观察塔 +天文台 +超越障碍训练场 +海洋 +章鱼 +提供 +办公室 +办公大楼 +办公椅 +办公室隔间 +办公桌 +办公用品 +办公室的窗户 +军官 +行政官员 +石油 +油灯 +油画 +石油钻台 +秋葵 +老照片 +橄榄 +橄榄油 +橄榄树 +煎蛋卷 +洋葱 +洋葱圈 +蛋白石 +开阔的/张开 +开始 +开幕式 +歌剧 +歌剧院 +操作 +手术室 +操作 +眼镜店 +猩猩 +橙子/橙色 +橙汁 +橙树 +橘园 +轨道 +果园 +乐池 +兰花 +订单 +组织 +折纸 +点缀 +鱼鹰 +鸵鸟 +水獭 +外面的 +露头 +户外 +厕所 +电源插头 +大纲 +椭圆形 +烤箱 +整体 +大衣 +天桥 +猫头鹰 +牡蛎 +橡皮环 +包裹 +包/包装/包裹 +围场 +警车 +挂锁 +肉菜饭 +宝塔 +疼痛 +油漆刷 +画家 +佩斯利印花大手帕 +宫殿 +调色板 +栅栏 +棺罩 +棕榈树 +平底锅 +煎饼 +熊猫 +面板 +全景 +三色堇 +喘息 +储藏室 +裤子 +连裤袜 +木瓜 +纸 +纸袋 +切纸机 +纸灯笼 +纸盘子 +纸巾 +平装书 +压纸器 +降落伞 +游行 +天堂 +鹦鹉 +护理人员 +长尾小鹦鹉 +滑翔伞 +伞兵 +羊皮纸 +教区 +公园 +公园长椅 +停车 +停车场 +停车费 +停车标志 +议会 +欧芹/香菜 +参与者 +合作伙伴 +帕特里奇 +聚会 +派对帽 +通过 +通道 +存折 +乘客 +客船 +旅客列车 +百香果 +护照 +面食 +粘贴 +糕点 +牧场 +补丁 +病人 +图案/款式 +人行道/硬路面 +大帐篷 +爪子 +支付 +付费电话 +豌豆 +和平 +桃子 +孔雀 +山峰/尖顶 +花生 +花生酱 +梨 +珍珠 +卵石 +山核桃 +行人 +人行天桥 +步行街 +果皮 +削皮器 +小钉板 +木质腿 +鹈鹕 +笔/围栏 +点球 +铅笔 +铅笔盒 +卷笔刀 +铅笔裙 +吊坠 +钟摆 +企鹅 +半岛 +锦标旗 +便士 +储蓄罐 +牡丹 +胡椒/辣椒 +胡椒研磨机 +胡椒子 +意大利辣香肠 +栖息/鲈鱼 +表演 +表演 +表演舞台 +香水 +绿廊 +波斯猫 +柿子 +个人护理 +个人漂浮装置 +害虫 +宠物 +宠物店 +宠物店 +花瓣 +佩妮 +教堂的长椅 +野鸡 +现象 +哲学家 +电话 +电话簿 +留声机 +照片 +照相亭 +相框 +摄影 +物理学家 +物理实验室 +钢琴家 +钢琴 +选择 +捡起 +泡菜 +野餐 +野餐区 +野餐篮 +野餐桌 +图片 +相框 +馅饼 +鸽子 +朝圣者 +药片 +枕头 +飞行员 +领航艇 +别针 +松树 +松果 +松林 +松子 +菠萝 +乒乓球桌 +乒乓球 +粉色 +一品脱的量 +琵琶 +管子 +管碗 +海盗 +海盗旗 +海盗船 +阿月浑子 +滑雪场 +口袋里的面包 +火龙果 +斗牛犬 +球场 +大水罐 +猪笼草 +干草叉 +披萨 +披萨刀 +比萨锅 +披萨店 +招牌 +地方 +餐具垫 +格子 +平原 +示意图 +行星 +行星地球 +厚木板 +植物 +种植园 +种植 +匾额 +石膏 +塑料 +橡皮泥 +高原 +平台 +白金 +大浅盘 +玩/演奏/运动 +打羽毛球 +打棒球 +打篮球 +玩台球 +踢足球 +玩乒乓球 +打网球 +打排球 +选手/运动员 +操场 +剧场 +扑克牌 +下棋 +打高尔夫球 +打麻将 +运动场 +护栏 +游戏室 +广场 +钳子 +故事情节 +犁 +插头 +插头帽 +李子 +水管工 +卫生洁具 +羽毛 +夹板 +口袋 +怀表 +随身小折刀 +圆荚体 +乐队指挥台 +诗歌 +一品红 +指/朝向 +指针 +扑克卡 +筹码 +扑克表 +杆/柱 +臭猫 +警察 +警车 +警犬 +警察局 +政治家 +圆点 +花粉 +污染 +马球 +马球领 +马球衬衫 +石榴 +波美拉尼亚的 +雨披 +池塘 +马尾辫 +贵宾犬 +池 +流行 +流行艺术家 +爆米花 +教皇 +罂粟 +瓷 +玄关 +猪肉 +粥 +便携式电池 +门户网站 +投资组合 +汽门 +肖像 +肖像会话 +摆姿势拍照 +负鼠 +帖子 +邮局 +邮票 +明信片 +海报 +海报页 +锅/罐/陶盆 +土豆 +土豆片 +土豆沙拉 +布垫子 +便壶 +袋 +家禽 +英镑 +倾泻 +粉末 +电源线 +电源插头及插座 +权力看 +电站 +练习 +布拉格城堡 +祈祷 +牧师 +首映 +处方 +显示 +演讲 +总统 +新闻发布室 +高压锅 +椒盐卷饼 +王子 +公主 +打印 +打印页面 +打印机 +印刷 +监狱 +农产品/生产 +产品 +职业 +专业的 +教授 +项目图片 +投影屏幕 +投影仪 +毕业舞会 +散步 +螺旋桨 +先知 +建议 +防护服 +抗议 +抗议者 +出版 +宣传画像 +冰上曲棍球 +布丁 +水坑 +泡芙 +角嘴海雀 +哈巴狗 +拉 +讲坛 +脉冲 +泵 +南瓜 +南瓜饼 +南瓜种子 +拳击吊袋 +拳头猛击/穿孔 +学生 +紫色 +推 +轻轻一击 +谜题 +塔 +金字塔 +大蟒 +二维码 +鹌鹑 +采石场 +季度 +石英 +女王 +油炸玉米粉饼 +队列 +乳蛋饼 +被子 +绗缝 +引用 +兔子 +浣熊 +比赛 +赛道 +水沟/跑道 +赛车 +球拍 +雷达 +散热器 +广播 +木筏/橡皮艇 +布娃娃 +栏杆/铁轨 +轨道车 +铁道 +铁路桥梁 +轨道线 +火车站 +雨 +雨靴 +彩虹 +虹鳟鱼 +雨衣 +热带雨林 +多雨的 +葡萄干 +耙子 +公羊 +斜坡 +油菜籽 +快速 +说唱歌手 +树莓 +老鼠 +棘轮 +乌鸦 +峡谷 +雷 +剃须刀 +锋利的 +阅读 +阅读材料 +钻孔器 +后面 +尾灯 +后视图 +后视镜 +收据 +收到 +接待 +配方 +记录 +唱片制作人 +记录器/竖笛 +录音室 +娱乐室 +休闲车 +矩形 +回收 +回收站 +红色 +红地毯 +红旗 +红熊猫 +红酒 +红木 +芦苇 +礁石 +卷轴 +裁判 +倒影 +倒影 +反射器 +注册 +控制 +驯鹿 +放松 +释放 +救援 +宗教 +宗教的 +享受 +保持 +改造 +遥控器 +移除 +修复 +维修店 +爬行动物 +救援 +救助者 +研究 +研究员 +储层 +住宅 +居民区 +树脂 +度假胜地 +度假小镇 +餐厅的厨房 +餐厅的露台 +厕所 +零售 +寻回犬 +制动火箭 +揭示 +犀牛 +杜鹃 +肋骨 +丝带 +大米 +电饭煲 +稻田 +骑/搭乘 +脊 +骑马 +步枪 +边缘 +环/戒指 +暴乱 +涟漪 +上升 +高层建筑 +河 +河岸 +河船 +河谷 +河床 +路 +路标 +公路旅行 +路边 +烤鸡 +长袍 +罗宾 +机器人 +石头 +岩石拱 +摇滚艺术家 +摇滚乐队 +攀岩者 +攀岩 +摇滚音乐会 +岩石表面 +岩层 +摇滚歌手 +火箭 +摇椅 +岩石 +啮齿动物 +牛仔竞技表演 +竞技舞台 +罗伊 +狍子 +辊 +过山车 +轮式溜冰鞋 +溜冰鞋 +擀面杖 +浪漫 +浪漫的 +屋顶 +屋顶花园 +房间 +房间分频器 +根 +根啤酒 +绳索桥 +念珠 +玫瑰 +迷迭香 +玫瑰色的云 +罗特韦尔犬 +圆桌 +路由器 +行 +罗文 +皇家 +橡皮图章 +废墟 +魔方 +红宝石 +莱夫 +橄榄球 +橄榄球 +橄榄球运动员 +毁坏 +尺 +朗姆酒 +跑 +跑步者 +跑步鞋 +农村的 +锈 +乡村的 +黑麦 +袋 +鞍 +鞍囊 +旅行 +安全 +安全背心 +圣人 +帆 +帆船 +航行 +水手 +松鼠猴 +缘故 +沙拉 +沙拉碗 +火蜥蜴 +意大利蒜味腊肠 +出售 +三文鱼 +沙龙 +萨尔萨舞 +盐 +盐和胡椒瓶 +盐湖 +盐沼 +盐瓶 +敬礼 +萨莫耶德人 +武士 +沙子 +沙洲 +砂箱 +沙堡 +沙雕 +凉鞋 +三明治 +卫生巾 +圣诞老人 +蓝宝石 +沙丁鱼 +莎丽 +生鱼片 +沙爹 +书包 +卫星 +缎 +酱汁 +碟子 +桑拿 +香肠 +稀树大草原 +锯 +锯木架 +萨克斯管 +萨克斯手 +脚手架 +秤/标尺 +比例模型 +扇贝 +疤痕 +稻草人 +围巾 +场景 +风景 +雪纳瑞犬 +学校 +校车 +校服 +校舍 +纵帆船 +科学 +科幻电影 +科学博物馆 +科学家 +剪刀 +壁灯 +司康饼 +勺子 +踏板车/摩托车 +分数 +记分板 +蝎子 +童子军 +炒蛋 +废弃 +刮板 +刮伤 +屏幕 +纱门 +截图 +螺杆 +螺丝刀 +长卷纸/卷轴 +擦洗 +硬毛刷 +雕塑家 +雕塑 +海洞穴 +海冰 +海狮 +海龟 +海胆 +尖吻鲈 +海底 +海鸟 +海鲜 +海马 +海豹 +海景 +海贝 +海滨度假胜地 +季节 +座位 +安全带 +海藻 +秘书 +安全 +小轿车 +看到 +种子 +跷跷板 +赛格威 +自拍 +出售 +研讨会 +感觉 +传感器 +服务器 +服务器机房 +服务 +集 +缝纫机 +影子 +摇 +瓶 +洗发水 +形状 +分享 +鲨鱼 +卷笔刀 +记号笔 +剃须刀 +剃须膏 +披肩/围巾 +剪切 +剪刀 +羊 +床单 +乐谱 +架子 +贝壳 +贝类 +避难所 +搁置 +牧羊人 +果子露 +柴犬 +发光 +航运 +集装箱 +海难 +船厂 +衬衫 +赤膊的 +浅滩 +鞋 +鞋盒 +鞋店 +鞋楦 +射击 +得分篮球后卫 +商店橱窗 +门面 +购物者 +购物 +购物袋 +购物篮 +购物车 +购物中心 +购物街 +海岸 +海岸线 +短的 +短发 +短裤 +小酒杯 +散弹枪 +肩膀 +单肩包 +铲 +陈列柜 +淋浴 +浴帽 +浴帘 +淋浴门 +淋浴头 +碎纸机 +泼妇 +虾 +神社 +灌木 +快门 +暹罗猫 +西伯利亚 +兄弟姐妹 +侧面 +边柜 +配菜 +边车 +边线 +壁板 +标志 +指示牌 +信号 +签名 +丝绸 +丝袜 +筒仓 +银 +银牌 +银器 +唱歌 +烧焦 +歌手 +水槽 +啜 +坐/放置/坐落 +坐着 +滑板公园 +滑板 +滑板者 +溜冰者 +溜冰场 +骨架 +草图 +串串 +滑雪 +滑雪靴 +滑雪设备 +滑雪服 +滑雪缆车 +滑雪杖 +滑雪胜地 +滑雪板 +滑雪 +滑雪鞋 +皮肤 +头骨 +无边便帽 +天空 +天空塔 +天窗 +天际线 +摩天大楼 +激流回旋 +石板 +雪橇 +睡眠 +睡袋 +睡衣 +袖子 +片 +滑动 +滑块 +吊索 +坡 +投币口 +老虎机 +树懒 +慢炖锅 +鼻涕虫 +贫民窟 +气味 +微笑 +烟雾/抽烟 +零食 +蜗牛 +蛇 +鲷鱼 +快照 +通气管 +鼻子 +雪 +雪豹 +雪山 +雪球 +单板滑雪者 +雪原 +雪花 +雪人 +雪地摩托 +雪犁 +雪鞋 +雪 +肥皂 +肥皂泡 +给皂器 +足球守门员 +社会名流 +短袜 +插座 +苏打水 +垒球 +软件 +太阳能电池阵列 +士兵 +独奏 +解决方案 +宽边帽 +歌曲 +声音 +汤 +汤碗 +汤匙 +酸奶油 +纪念品 +豆浆 +水疗中心 +空间 +航天飞机 +空间站 +宇宙飞船 +意大利面 +横跨 +扳手 +火花 +闪耀 +烟火 +起泡葡萄酒 +麻雀 +抹刀 +扬声器 +观众 +会话框 +速度限制 +限速标志 +快艇 +车速表 +球 +香料 +调料架 +蜘蛛 +蜘蛛网 +扣球 +旋转 +菠菜 +尖塔 +飞溅 +海绵 +勺子 +体育协会 +运动器材 +运动团队 +体育球 +体育器材 +运动会 +运动服装 +点 +喷雾 +伸展 +春天 +春卷 +撒 +洒水器 +发芽 +云杉 +云杉森林 +队 +广场 +南瓜 +蹲 +挤 +鱿鱼 +松鼠 +水枪 +刺 +稳定的 +(码放整齐的)一叠 +体育场 +工作人员 +舞台 +舞台灯 +驿马车 +弄脏 +不锈钢 +楼梯 +楼梯 +楼梯间 +摊位/小隔间 +种马 +站/矗立/摊位 +站 +主食 +订书机 +星星 +盯着 +海星 +杨桃 +燕八哥 +州立公园 +公立学校 +车站 +固定自行车 +文具 +雕像 +牛排 +牛排刀 +蒸汽 +蒸汽机 +蒸汽机车 +蒸汽火车 +馒头 +钢 +方向盘 +(花草的)茎 +模版 +梯凳 +立体声 +听诊器 +炖 +戳/条状物 +竹节虫 +贴纸 +静物画 +高跷 +黄貂鱼 +搅拌 +搅拌器 +镫 +缝 +股票 +长筒袜 +腹部 +石头建筑 +石雕 +石屋 +石磨 +凳子 +停止 +停在 +红灯 +停车标志 +秒表 +红绿灯 +存储箱 +储藏室 +罐/蓄水池 +商店 +店面 +鹳 +风暴 +暴风云 +狂风暴雨的 +炉子 +扑克 +跨骑 +过滤器 +海峡 +带 +稻草/吸管 +草帽 +草莓 +溪流 +街头艺术 +街头艺术家 +街角 +流浪狗 +街头食品 +路灯 +街市场 +街头摄影 +街景 +路标 +街头小贩 +拉伸 +担架 +罢工 +前锋 +细绳 +芝士条 +带子 +条纹 +漫步 +结构 +工作室 +影棚拍摄 +材料 +填充玩具动物 +毛绒玩具 +馅 +树桩 +惊人的 +特技 +佛塔 +风格 +手写笔 +潜艇 +潜艇形大三明治 +海底水 +郊区 +地铁 +地铁站 +低音炮 +多肉 +绒面革 +糖 +糖碗 +甘蔗 +方糖 +西装 +套房 +夏天 +夏天傍晚 +峰顶 +太阳 +太阳帽 +日光浴 +周日 +日晷 +向日葵 +向日葵田 +葵花籽 +太阳镜 +晴天 +日出 +日落 +遮阳伞 +阳光 +超级碗 +跑车 +超级英雄 +超市 +超市货架 +超模 +支持者 +冲浪 +表面 +冲浪板 +冲浪者 +外科医生 +外科手术 +环绕 +寿司 +寿司吧 +背带裤 +悬架 +吊桥 +越野车 +燕子 +燕尾蝶 +沼泽 +天鹅 +天鹅游艇 +运动裤 +防汗带 +毛衣 +运动衫 +甜的 +红薯 +游泳 +泳帽 +游泳者 +游泳洞 +游泳池 +摆动 +平转桥 +秋千 +漩涡 +开关 +转椅 +剑 +旗鱼 +象征 +对称 +犹太教堂 +注射器 +糖浆 +系统 +t恤 +t恤 +塔巴斯科辣椒酱 +虎斑 +乒乓球拍 +桌面 +桌布 +平板电脑 +餐具 +转速表 +拦截 +墨西哥煎玉米卷 +跆拳道 +太极 +尾巴 +裁缝 +拍/拿 +起飞 +说话/交谈/演讲 +手鼓 +棕褐色 +橘子 +胶带/磁带/终点线 +挂毯 +沥青碎石路面 +芋头 +篷布 +果馅饼 +流苏 +味道 +榻榻米 +纹身 +纹身艺术家 +酒馆 +茶 +茶包 +茶话会 +茶园 +茶壶 +茶具 +教 +老师 +茶杯 +水鸭 +团队合影 +团队介绍 +眼泪/撕裂/划破 +技术员 +技术 +泰迪熊 +T字形物 +青少年 +电线杆 +变焦镜头 +望远镜 +电视 +电视摄像机 +电视室 +电视演播室 +温度 +寺庙 +天妇罗 +网球 +网球场 +网球比赛 +网球网 +网球运动员 +网球拍 +帐篷 +龙舌兰酒 +终端/航站楼 +阳台 +地形 +玻璃容器 +领土 +测试 +测试赛 +试管 +文本 +短信 +纺织 +纹理 +感恩节 +感恩节晚餐 +剧院 +戏剧演员 +治疗 +温度计 +热水瓶 +暖瓶 +恒温器 +灌木丛 +顶针 +东西 +思考 +蓟 +宝座 +金銮殿 +扔 +抱枕 +雷 +雷雨 +百里香 +皇冠 +记号 +票 +售票亭 +潮池 +领带 +老虎 +紧 +瓦 +瓷砖地板 +瓦屋顶 +瓷砖墙 +锡 +锡纸 +箔 +提拉米苏 +轮胎 +纸巾 +烤面包 +烤面包机 +烟草 +烟斗 +学步的小孩 +脚趾 +豆腐 +马桶 +马桶座圈 +化妆包 +东京铁塔 +番茄 +番茄酱 +番茄汤 +墓 +钳子 +钳子 +工具 +工具箱 +牙刷 +牙膏 +牙签 +修剪成形的花园 +配料 +火炬/光源 +龙卷风 +玉米粉圆饼 +乌龟 +大手提袋 +图腾柱 +龙猫 +巨嘴鸟 +触摸 +触地 +旅行 +旅游巴士 +导游 +游客 +旅游景点 +锦标赛 +拖车 +毛巾 +毛巾杆 +大厦 +塔桥 +小镇 +城镇广场 +玩具 +玩具车 +玩具枪 +玩具店 +跑道 +拖拉机 +贸易 +传统 +传统的 +交通 +锥形交通路标 +交通拥堵 +交通堵塞 +交通标志 +小道 +预告片 +拖车 +火车 +火车桥 +火车车厢 +火车内部 +火车轨道 +火车窗口 +教练 +训练 +训练长椅 +训练场 +电车/手推车 +蹦床 +变形金刚 +透明度 +旅行 +托盘/碟子 +跑步机 +美食 +树 +树枝 +林场 +树蛙 +树屋 +树根 +树干 +试验 +三角形 +铁人三项 +部落 +支流 +戏法/特技 +三轮车 +修剪 +三人组 +三脚架 +长号 +部队 +奖杯 +奖杯 +热带 +鳟鱼 +卡车 +卡车司机 +浴缸 +管子 +拖船 +郁金香 +金枪鱼 +苔原 +隧道 +涡轮 +火鸡 +转动 +芜菁 +绿松石 +炮塔 +乌龟 +獠牙 +电视演员 +电视柜 +电视剧 +电视节目类型 +电视名人 +电视节目 +情景喜剧 +电视塔 +枝条 +黄昏 +双胞胎 +麻线 +扭 +类型 +键入 +打字机 +尤克里里 +奥特曼 +伞 +内衣 +水下 +独角兽 +制服 +宇宙 +大学 +向上 +城市 +尿壶 +瓮 +使用 +用具 +杂物间 +吸尘器/真空 +谷 +阀门 +吸血鬼 +货车 +香草 +虚荣 +种类 +花瓶/瓶 +金库 +矢量卡通插图 +矢量图标 +蔬菜 +菜园 +蔬菜市场 +植被 +车辆 +面纱 +静脉 +天鹅绒 +自动售货机 +小贩 +通风孔 +胡蜂属 +船 +背心 +兽医 +经验丰富的 +兽医办公室 +高架桥 +视频 +摄像机 +电子游戏 +录像带 +视镜 +守夜 +别墅 +村庄 +藤蔓 +醋 +葡萄园 +暴力 +紫罗兰色 +小提琴 +小提琴家 +中提琴演奏者 +愿景 +遮阳板 +伏特加 +火山 +排球 +排球场 +排球运动员 +志愿者 +航行 +秃鹰 +华夫饼干 +华夫饼机 +货车 +马车车轮 +腰 +服务员 +候机室 +等候室 +走 +步行 +手杖 +挂钟 +壁纸 +核桃 +海象 +战争 +仓库 +温暖的 +警告标志 +战士 +军舰 +疣猪 +洗 +洗衣机/垫圈 +洗 +洗衣机 +黄蜂 +浪费 +废物容器 +手表 +水 +水鸟 +水牛 +水冷却器 +水滴 +水景 +热水器 +水位 +荷花 +水上乐园 +水管 +净水器 +滑水板 +水上运动 +水面 +水塔 +水彩 +水彩插图 +水彩画 +瀑布 +喷壶 +水印叠加图章 +西瓜 +防水外套 +水路 +波浪 +蜡 +武器 +穿着 +天气 +叶片 +网 +摄像头 +婚礼 +结婚戒指 +婚礼花束 +结婚蛋糕 +新婚夫妇 +婚礼请柬 +婚礼派对 +婚纱照 +婚礼摄影师 +婚纱摄影 +婚宴 +楔 +杂草 +重量 +体重秤 +焊接工 +井 +西餐 +西餐厅 +湿 +吧台 +潜水衣 +湿地 +潜水服 +鲸鱼 +鲸鲨 +小麦 +麦田 +车轮 +轮椅 +后轮支撑车技 +生奶油 +搅拌器 +胡须 +威士忌 +哨子 +白色 +白宫 +白葡萄酒 +白板 +便门 +宽的 +挥动 +假发 +Wii +Wii手柄 +荒野 +角马 +野火 +野花 +野生动物 +柳树 +风 +风铃 +风电场 +风力涡轮机 +风车 +窗户 +窗台花盆箱 +橱窗展示 +窗框 +纱窗 +靠窗的座位 +窗台 +雨刮器 +挡风玻璃 +有风的 +酒瓶 +冷酒器 +酒柜 +酒窖 +酒杯 +酒架 +品酒 +酒庄 +翅膀 +冬天 +冬瓜 +冬天的早晨 +冬季场景 +冬季运动 +冬季风暴 +电线 +紫藤 +巫婆 +女巫帽子 +炒锅 +狼 +女人 +木头 +林鸳鸯 +木地板 +木墙 +烧木炉 +木匙 +林地 +啄木鸟 +木工刨 +羊毛 +工作 +练习卡 +工作台 +工人 +工作场所 +车间 +世界 +蠕虫 +敬拜 +伤口 +包 +裹身裙 +包装纸 +搏斗 +摔跤手 +皱纹 +腕带 +写 +作家 +手写/字迹 +毛笔 +写字桌 +游艇 +牦牛 +院子 +黄色 +瑜伽 +瑜伽垫 +酸奶 +轭 +蛋黄 +青年 +青年旅馆 +蒙古包 +斑马 +斑马线 +禅意花园 +拉链 +拉链 +僵尸 +粽子 +动物园 diff --git a/ppcls/utils/ram/ram_tag_list_threshold.txt b/ppcls/utils/ram/ram_tag_list_threshold.txt new file mode 100644 index 0000000000..f8a861ebb4 --- /dev/null +++ b/ppcls/utils/ram/ram_tag_list_threshold.txt @@ -0,0 +1,4585 @@ +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.71 +0.75 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.61 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.7 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.82 +0.8 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.85 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.77 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.89 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.78 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.9 +0.65 +0.83 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.79 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.79 +0.65 +0.63 +0.65 +0.87 +0.8 +0.46 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.8 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.84 +0.65 +0.65 +0.79 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.81 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.83 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.77 +0.87 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.85 +0.65 +0.68 +0.65 +0.8 +0.65 +0.65 +0.75 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.8 +0.8 +0.79 +0.65 +0.85 +0.65 +0.65 +0.65 +0.9 +0.65 +0.89 +0.8 +0.65 +0.65 +0.65 +0.76 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +1 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.89 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.71 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.8 +0.8 +0.9 +0.65 +0.85 +0.8 +0.8 +0.8 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.75 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.63 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.71 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.9 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.71 +0.65 +0.8 +0.76 +0.85 +0.8 +0.65 +0.65 +0.8 +0.65 +0.79 +0.65 +0.75 +0.65 +0.8 +0.65 +0.86 +0.65 +0.65 +0.9 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.73 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.8 +0.75 +0.65 +0.65 +0.65 +0.65 +0.8 +0.85 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.77 +0.65 +0.65 +0.65 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.6 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.74 +0.65 +0.65 +0.67 +0.65 +0.65 +0.8 +0.65 +0.65 +0.85 +0.65 +0.8 +0.65 +0.65 +0.84 +0.8 +0.8 +0.8 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.89 +0.65 +0.65 +0.65 +0.83 +0.65 +0.65 +0.65 +0.65 +0.6 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.77 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.74 +0.65 +0.65 +0.66 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.84 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.88 +0.65 +0.65 +0.8 +0.65 +0.65 +0.7 +0.65 +0.65 +0.65 +0.9 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.82 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.75 +0.65 +0.7 +0.9 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.88 +0.65 +0.65 +1 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.71 +0.65 +0.65 +0.65 +0.79 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.88 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.82 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.87 +0.65 +0.66 +0.65 +0.84 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.84 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.5 +0.65 +0.64 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.81 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.84 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.8 +0.65 +0.85 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.73 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.82 +0.8 +0.65 +0.65 +0.65 +0.84 +0.9 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.64 +0.65 +0.65 +0.65 +0.8 +0.8 +0.87 +0.65 +0.65 +0.78 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.74 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.83 +0.89 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.8 +0.84 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.81 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.7 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.82 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.87 +0.65 +0.9 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.73 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.89 +0.8 +0.65 +0.9 +0.65 +1 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.89 +0.89 +0.65 +0.65 +0.65 +0.8 +0.75 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.88 +0.65 +0.8 +0.65 +0.65 +0.8 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.9 +0.57 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.8 +0.79 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.89 +0.8 +0.65 +0.8 +0.65 +0.8 +0.65 +0.81 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.84 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.8 +0.83 +0.65 +0.65 +0.8 +0.65 +0.65 +0.72 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +1 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.69 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.71 +0.65 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.85 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.9 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.85 +0.65 +0.65 +0.8 +0.65 +0.89 +0.65 +0.65 +0.9 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.86 +0.65 +0.77 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.75 +0.8 +0.65 +0.8 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.82 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.83 +0.65 +0.65 +0.92 +0.89 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.75 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.85 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.87 +0.65 +0.79 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.83 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.7 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.65 +0.65 +0.65 +0.65 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.8 +0.82 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +1 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.64 +0.65 +0.65 +0.63 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.76 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.65 +0.75 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.87 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.82 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.89 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.65 +0.73 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.86 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.9 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.86 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.69 +0.65 +0.65 +0.65 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.72 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.9 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.45 +0.8 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.51 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.66 +0.65 +0.8 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.81 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.75 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.66 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.65 +0.85 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.81 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.79 +0.75 +0.65 +0.65 +0.8 +0.65 +0.67 +0.8 +0.8 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.81 +0.8 +0.65 +0.65 +0.9 +0.65 +0.79 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.77 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.74 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.6 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.89 +0.8 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.9 +0.75 +0.65 +0.65 +0.65 +0.8 +0.6 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.84 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.85 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.63 +0.65 +0.65 +0.65 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.84 +0.65 +0.65 +0.8 +0.65 +0.81 +0.8 +0.8 +0.8 +0.82 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.65 +0.88 +0.65 +0.8 +0.65 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +1 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.74 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.9 +0.86 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.64 +0.65 +0.65 +0.8 +0.8 +0.65 +0.87 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.7 +0.65 +0.65 +0.8 +0.65 +0.65 +0.75 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.85 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.71 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.73 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.75 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.88 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.81 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.9 +0.65 +0.65 +0.65 +0.65 +0.7 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.77 +0.65 +0.65 +0.65 +0.65 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.57 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.76 +1 +0.8 +0.65 +0.65 +0.58 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +1 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.8 +0.9 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.68 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.99 +0.8 +0.77 +0.65 +0.9 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.8 +0.8 +0.65 +0.7 +0.65 +0.65 +0.8 +0.9 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.77 +0.65 +0.65 +0.65 +0.65 +0.79 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.52 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.86 +0.65 +0.65 +0.8 +0.56 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.72 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.8 +0.6 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.89 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.75 +0.65 +0.65 +0.65 +0.65 +0.54 +1 +0.65 +0.65 +0.75 +0.65 +0.75 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.9 +0.62 +0.65 +0.65 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.82 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.74 +0.8 +0.65 +0.8 +0.8 +0.7 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.8 +0.8 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.8 +0.8 +0.84 +0.8 +0.65 +0.65 +0.8 +0.75 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.82 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.84 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.65 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.74 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.85 +0.65 +0.9 +0.9 +0.65 +0.65 +0.65 +0.63 +0.82 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.74 +0.9 +0.65 +0.8 +0.65 +0.65 +0.58 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.75 +0.65 +0.65 +0.8 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.8 +0.65 +0.64 +0.65 +0.65 +0.65 +0.8 +0.87 +0.65 +0.65 +0.8 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.83 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.78 +0.65 +0.8 +0.65 +0.9 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.88 +0.8 +0.65 +0.65 +0.65 +0.81 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.77 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.8 +0.65 +0.65 +0.65 +1 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.85 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.68 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.5 +0.65 +0.65 +0.65 +0.8 +0.9 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.81 +0.65 +0.65 +0.65 +0.8 +0.85 +0.65 +0.77 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.9 +0.65 +0.65 +0.89 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.88 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.82 +0.65 +0.8 +0.74 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.7 +0.65 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.7 +0.7 +0.8 +0.65 +0.65 +0.65 +0.65 +0.87 +0.8 +0.65 +0.65 +0.65 +0.89 +0.85 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.8 +0.65 +0.66 +0.57 +0.65 +0.65 +0.65 +0.49 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.76 diff --git a/ppcls/utils/ram/tag2text_ori_tag_list.txt b/ppcls/utils/ram/tag2text_ori_tag_list.txt new file mode 100644 index 0000000000..11a61b68fb --- /dev/null +++ b/ppcls/utils/ram/tag2text_ori_tag_list.txt @@ -0,0 +1,3429 @@ +tennis +bear cub +observatory +bicycle +hillside +judge +watercolor illustration +granite +lobster +livery +stone +ceramic +ranch +cloth +smile +building +tattoo +cricketer +cheek +pear +source +winter +surface +spray +ceremony +magic +curve +container +fair +medicine +baby +tennis racquet +ornament +bamboo +duckling +song +safari +team presentation +daffodil +cross +toothpaste +shield +fashion model +capsule +map +creek +glass house +glass plate +siding +corner +water buffalo +bison +figure skater +diploma +tire +race +cable car +brain +gas stove +soap bubble +palette +snowboard +school child +trench coat +monk +fiber +kitchen window +sunglass +coffee +security +strawberry +penguin +tree root +loaf +engagement ring +lamb +vector cartoon illustration +sandwich +mountain village +shape +charm +fiction +knot +greenhouse +sushi +text +disaster +trophy +gang +strap +soccer game +cardinal +tee +turtle +water surface +grassland +dolphin +store +dirt +iceberg +pergola +farmer market +publicity portrait +tote bag +teenage girl +view mirror +session +commuter +dressing room +tricycle +christmas ball +headlight +police +armchair +chart +yacht +saw +printer +rock band +gingerbread house +tag +table lamp +hockey game +slope +font +wicker basket +jewelry +quarter +software +weapon +pin +worship +painter +goal +morning light +bike +baseball bat +elevator +cuisine +sausage +stunt +wrestler +statue +landing +pillar +willow tree +sea wave +chicken +peanut +muscle +bob +tv genre +bathroom window +radish +textile +pelican +marketplace +crest +elevation map +gift +parish +traffic light +campfire +fog +award winner +beach ball +mat +white house +plaster +moped +football team +solution +bicyclist +bit +playground +darkness +cake +maple leave +mold +cracker +blueberry +rubble +container ship +pedestrian bridge +snail +parrot +form +circuit +highlight +pickup truck +koala +rain +system +weather +raincoat +soccer team +windshield +thunderstorm +mike +bird house +bridge +grandfather +restroom +animation +wilderness +clown +banana +brown +braid +dining room +kindergarten +launch event +purple +school +stairwell +brooch +movie poster image +mountain river +shelf +wicket +headboard +buddha +flower field +dugout +cd +bald eagle +lagoon +seaweed +agriculture +emergency service +maple tree +parachute +continent +amusement park +remote +bun +tackle +hospital +garage door +birthday party +friendship +go +mausoleum +jeep +raccoon +step +ice hockey team +cigarette +lace dress +forest floor +mall +captain +milk +golf course +meal +picnic table +sail +volleyball +canal +terrace +computer desk +caravan +hotel +cheerleader +nurse +museum +marsh +fox +plateau +night +twin +letter logo +autumn tree +powder +convention +creature +lighthouse +shop window +jacket +stork +taxi +trade +blackboard +olive +road sign +resort +snowflake +cemetery +travel +evening dress +picnic +drink +winter morning +football player +snack +boxing glove +dinner party +airline +swing +port +wheelbarrow +bathroom sink +sweater +ambulance +gear +oil +wii controller +array +home office +car show +mixture +profession +tree frog +square +facility +coral reef +sea wall +pizza +exhibit +demolition +trout +ring +coffee shop +bracelet +bean +lip +fencing +landscape +sitting +package +metal +bust +king +hair +window seat +wildlife +trunk +greenery +stencil +fire hydrant +bridesmaid +plaza +alps +tower bridge +crop top +crossing +cinema +pedestrian crossing +family +shopping cart +stomach +church building +screen door +skater +soccer field +kettle +mussel +raindrop +candy cane +water lily +flower girl +desert +enclosure +christmas light +kitchen +caterpillar +plaid +bath +bush +mud +ballet +knee +adult +raft +sea view +cactus +office chair +overall +rim +scaffolding +pig +cover +poster page +sprinkle +chandelier +algae +traffic +surfboard +book +filming +flash +mansion +camouflage +trouser +ticket +weed +cab +trench +elephant +huddle +sphere +christmas decoration +city +launch +doll +christmas ornament +fabric +bikini +biplane +breakfast +neighbourhood +race track +foliage +avocado +school bus +footwear +highway +ocean view +art vector illustration +wall clock +curtain +teenager +kitchen area +robot +tusk +lounge chair +beam +paddle +camel +lid +world map +city view +newlywed +cargo ship +yellow +exhibition +bend +novel +wool +ontario +bread +campus +coastline +cutting board +booth +table top +carpet +beach chair +workout +street food +fun +costumer film designer +gadget +artist +fishing village +builder +violinist +iphone +spider web +traffic sign +ruin +rescue +clipboard +seal +film director +paw +nursery +intersection +tomato sauce +taste +paddy field +christmas tree +wave +stool +watering can +rug +daytime +subway station +craft +pine forest +black +planet +motif +christmas market +glass window +college +wheat +damage +rectangle +picture frame +chess +guest room +street corner +religion +seed +puzzle +freeway +beauty +ocean +watch +mother +garage +quote +dj +supporter +hip hop artist +muffin +eiffel tower +cash +firefighter +cauliflower +bunker +sled +manicure +shark +stall +jungle +family home +tour bus +chimney +touchdown +roundabout +coyote +street scene +tank +wedding dress +mantle +bedroom window +coconut +chapel +goat +living space +rock wall +polka dot +railway +mandala +mango +lesson +mountain landscape +team photo +bookshelf +meter +bulldog +evening sun +stick +card +pink +fish pond +paint +pill +cart +pea +van +album +football college game +mountain pass +doughnut +ski slope +match +official +shadow +organ +celebration +coin +log cabin +firework display +present +twig +chef +confetti +footpath +tour +ponytail +artwork +race car +club +season +hose +pencil +aircraft +rock formation +wardrobe +participant +politician +engineer +peace +filter +sailing boat +water bottle +service dog +poodle +loki +statesman +sleeping bag +outskirt +clock +factory +oak tree +physician +color +room +stairway +company +lady +graph +faucet +tablecloth +subway train +chocolate chip cookie +headquarters +screw +goggle +halloween +city street +swirl +cord +forward +bone +bedding +archway +wig +lobby +mask +attic +kitchen table +skylight +fire +exit +oil painting +passenger +meditation +salmon +fedora +rubber stamp +orange juice +arch +scientist +stroll +manhattan +float +baseball uniform +circle +church +decker bus +competitor +zoo +basketball team +tourist +daughter +silverware +ceiling fan +birth +vase +jack +mushroom +spiral +cage +limb +salad +ad +control +earth +party +bolt +tractor +barley +wedding photo +hawk +warehouse +vegetable garden +chocolate cake +cabbage +floor window +baby shower +magnifying glass +table +stethoscope +reading +mission +croissant +gift box +rocket +forest road +cooking +suite +hill country +motorcycle +baseball player +angle +drug +sport association +championship +family portrait +florist +softball +egret +office +plywood +jockey +mosque +brunch +beanie +office building +pattern +calendar +indoor +pepper +ledge +trail +fuel +laptop computer +tennis shoe +deck chair +guitarist +barn +surgery +cartoon illustration +nebula +railroad +mountain goat +goose +car door +cheer +liquid +hardwood floor +pathway +acorn +gull +airliner +couch +lake house +spaghetti +promenade +collection +garden +bank +robin +tennis ball +peony +gymnast +lavender +deck +test +riverside +rapper +domino +bride +mouse +basil +wedding couple +ocean wave +arm +kitchen floor +grove +family member +backyard +raspberry +forest fire +officer +hibiscus +canyon +composer +signature +olive oil +hibiscus flower +rose +vector icon +sunrise +horseback +motor scooter +office worker +tradition +ingredient +washing machine +lighting +bagel +sailboat +policeman +mare +graphic +halloween pumpkin +stock +pilot +education +team +body +horse +kimono +bazaar +bag +recording studio +parsley +entrance +denim +vet +horse farm +charcoal +architecture +glass vase +puppy +estuary +television show host +city bus +shoulder +beast +balance +golfer +roadside +denim jacket +stone wall +counter top +app icon +toast +head coach +ham +warrior +gem +refrigerator +snowman +construction worker +coal +website +morning fog +mustard +human +owl +puppy dog +piggy bank +vegetation +pirate +action film +marshmallow +thanksgiving +business +disease +signage +greeting +skate park +tile +mouth +spinach +vacation +leader +shrine +walker +science fiction film +bill +rabbit +motor boat +bar +radio +barge +tail +chainsaw +gallery +rainbow +pasta +padlock +web +pastry +ink +reef +school uniform +shawl +treasure +peach +dinner table +injury +harbor +witch +car dealership +litter +gesture +documentary +marriage +sea shell +priest +dome +kit +icon +seaside +bucket +entertainment +stable +hat +puddle +sock +shopper +technology +harbour +orbit +antler +tube +flag waving +cook +tight +commander +farmland +switch +hiker +wedding ceremony +award ceremony +champion +chopstick +farmhouse +performer +spike +accident +cruise ship +passenger train +attraction +entertainer +rear view +sidewalk +parade +racing +plane +ritual +peacock +pocket +plum +drop +carrot +floor +sunset +troop +architect +coffee table +dust +outline +leather +charity event +heat +whale +laundry +coconut tree +crosswalk +pony +ant +pipe +string +coat +angel +beef +church tower +dish +pitch +cupboard +thermometer +dirt field +fireworks +minute +cane +pajama +flower garden +autumn +trash can +dachshund +banana tree +tray +moose +roadway +carnival +antenna +pole +castle wall +ram +cattle +hay +cookie +swimmer +baseball team +strait +hedge +jet +fire pit +octopus +calf +cube +opera +cardboard box +tiara +kitchen sink +prairie +bowl +galaxy +straw hat +linen +ski resort +stitch +street lamp +motorist +icicle +stain +flora +drain +kitchen cabinet +decor +bouquet +pound +interior design +nail polish +figurine +tomb +disc +twist +blouse +ribbon +figure +burger +cork +soccer goalkeeper +train bridge +drinking water +dew +baker +storm cloud +tarmac +tv drama +sponge +magnet +sailor +entry +swan +exercise +sloth +jewel +scuba diver +bite +cat tree +tent +can +tennis match +ecosystem +picket fence +palm +train car +frying pan +rally +tablet pc +reindeer +image +wolf +chin +conservatory +flood water +cityscape +beach sand +car park +pavement +farm field +swimming +winter storm +stem +pillow +inning +gorilla +desk +avenue +fern +money +pearl +train station +skillet +nap +barber +library +freezer +label +rainforest +parking sign +mirror +wing +noodle +press room +sculpture +tablet +viewer +prayer +mini +mechanic +laugh +rice field +hand +mustache +mountain road +catwalk +conference +cape +installation +musician +stream +machine +speech +crocodile +soccer match +town square +passport +post box +point +stone building +motorway +mix +dentist +businessperson +happiness +boat +vineyard +treadmill +glass wall +water droplet +coffee mug +graduate +sunflower +parliament +shepherd +movie +wine +orchard +tulip +motherboard +cup +broom +spot +drawing +polo shirt +graduation +film producer +moonlight +glow +film format +t shirt +rock face +sword +clinic +festival day +meadow +staple +pupil +training ground +rider +flower +foal +wharf +foot bridge +shooting +top +mast +police car +robe +wedding bouquet +stop sign +birthday cake +glitter +butter +scooter +tundra +superhero +pocket watch +inscription +youngster +fruit tree +movie poster +engine +foundation +motorcyclist +take +woman +antelope +country artist +road trip +typewriter +tuxedo +brand +pine +bathroom +paradise +texture +balloon +dining table +home +computer screen +actor +clip +tv tower +panorama +summit +cat +plot +eagle +dancer +pup +studio shot +tear +bird bath +classroom +bookstore +city wall +tv programme +blade +easel +buttercream +sweet +designer +diamond +handshake +herb +corn field +seafront +concrete +street artist +gas +stamp +window display +paper +note +pint +quarry +research +fixture +manager +soil +leopard +board game +ladder +stop light +island +ramp +football match +icing +drill +currency +summer evening +topping +pyramid +pomegranate +cell +ivy +squad +scenery +computer +locomotive +surf +mascot +dune +path +duck +twilight +wire +bow tie +strike +cormorant +car wash +crane +market +philosopher +alarm clock +camera +birch +greeting card +plain +clay +donut +lock +moth +laboratory +fan +violin +jazz fusion artist +mountain biker +terrain +magazine +pickup +comedy film +smartphone +film +bed +microwave oven +tournament +lawn +car window +alligator +screen +jetty +shopping bag +landscape view +cabinetry +friendly match +thing +petal +shopping center +transport +ballet dancer +shoreline +princess +car seat +parking meter +green +vodka +band +rock +costume +warning sign +strip +plaque +wheelchair +headband +ginger +dice +media +hairdresser +press +living room +stove +player +cherry +workshop +carving +embroidery +doodle +adventure +rugby player +monument +brush +marker +loft +postcard +collage +ball +professor +dresser +gig +festival +blackbird +makeup artist +video camera +sticker +peak +wildflower +santa hat +rodeo +wedding photographer +guy +staff +waterfall +operation +defender +falcon +haze +individual +gentleman +greyhound +rocking chair +rice +garbage +platter +chocolate +splash +business suit +cheetah +valley +maze +trampoline +garland +slalom +unicorn +tree stump +painting +romance +fight +alcohol +ghost +fondant +spa +shutter +death +demonstration +cotton +pier +flea market +history +savannah +fist +aisle +crew +jug +pose +anchor +teapot +boat house +business team +tripod +bee +pebble +mattress +canvas +hallway +campaign +pod +lake district +article +white +sofa +honey +marathon +pancake +tourist attraction +wedding gown +battle +shelving +sea +sheet music +pie +yarn +construction site +flyer +tie +star +lettuce +martial artist +dart +straw +reflection +conference room +temperature +rugby +mosquito +physicist +rock climber +crash +backdrop +toilet seat +sand castle +water park +toy car +waste +luxury +hangar +rv +tree trunk +board +gold +project picture +cap +cottage +relief +attire +microscope +battery +roll +line +parking garage +crystal +broadcasting +brick wall +lab +flooring +meeting +3d cg rendering +desktop computer +cowboy +sailing ship +junction +hairstyle +homework +profile +model +flower pot +street light +salt lake +maple +space +blizzard +throw +zebras +brochure +constellation +beak +kilt +pond +blue sky +sneaker +sand dune +morning sun +almond +grill +curl +basketball girl game +chameleon +toilet bowl +prince +keyboard +queen +computer monitor +writing +crown +basilica +kiss +house +parking +football competition +shell +sport equipment +comedy +baboon +vendor +rise building +wrap +food truck +cat bed +rickshaw +flare +teal +nectar +eclipse +vehicle +steam locomotive +gorge +cow +christmas card +demonstrator +memorial +towel +jewellery +train +frisbee +baseball game +fur +afternoon sun +community +sparkler +bandage +firework +dollar +pasture +video +bus +tree house +seashore +field +hamburger +souvenir +hedgehog +worm +pine cone +osprey +dinosaur +vegetable +junk +poster +army +winger +bundle +stage +growth +wedding party +service +blanket +ruler +eye +credit card +castle +diner +hut +elk +hard rock artist +nun +dog breed +nest +drama film +number icon +water tank +giraffe +altar +pavilion +tv personality +suv +street vendor +street sign +ditch +debris +foam +takeoff +spice +mountain lake +tea +orchestra +spacecraft +counter +abbey +mountain +hydrangea +racer +orange tree +tide +cowboy hat +rapid +town +wild +herd +vein +driveway +jar +bark +illustration +horror film +corn +stroller +industry +mountain stream +gym +neckline +pan +client +spectator +eggplant +camper +fawn +hoodie +meat +lemonade +food market +slum +comic book character +flower market +love +palace +gun +heel +shopping street +shooting basketball guard +family photo +rooftop +laundry basket +airport runway +horn +face mask +flight +appetizer +violet +country lane +cement +instrument +tv actor +spark +celebrity +award +country house +standing +auction +date +engagement +puck +advertisement +chair +zebra +driftwood +bumblebee +maple leaf +bonnet +orange +water tower +door +singer +floor plan +discussion +theatre +pilgrim +mug +branch +window sill +baseball pitcher +bakery +lollipop +basketball player +toilet paper +chalkboard +cabin +sign +night sky +cannon +fishing net +submarine +suit +fur coat +wine bottle +folder +street art +suspension bridge +evening sky +billboard +postage stamp +newspaper +transportation +surgeon +light +park +horizon +road +sand bar +trumpet +lounge +cloud forest +birthday celebration +balcony +anime +beehive +umbrella +goldfish +baseball cap +waterhole +ceiling +carousel +backpack +plant pot +atmosphere +sunflower field +spire +vision +woodpecker +chip +pool table +lotus flower +cone +humpback whale +reservoir +hunt +piano +plate +dining area +luggage +skier +dance floor +crow +stair +overpass +opera house +bear +jazz artist +water +vessel +cast +yard +cathedral +basketball hoop +graveyard +sound +berry +onlooker +fauna +birch tree +retail +hill +skeleton +journalist +frost +basket +nail +dusk +trash +dawn +clover +hen +volcano +basketball coach +home decor +charge +haircut +sense +university +lizard +daisy +tablet computer +grass field +prison +metal artist +bathroom mirror +window frame +chest +flavor +pop country artist +market square +monkey +blog +deer +speech bubble +dog +independence day +girl +boy +tartan +furniture +appliance +office window +fish boat +sand box +tv sitcom +drama +sleigh +depression +paper towel +baseball +protestor +grape +wedding cake +invitation +accessory +pick +grandparent +racket +tea plantation +outdoors +egg +glass bowl +sun +organization +lion +panel +station +wallpaper +helicopter +salt +vanity +patio +lunch +street performer +mountain range +soup +bacon +power station +cantilever bridge +hummingbird +shirt +rope +hip +chalk +pendant +choir +tv +lichen +railway bridge +art gallery +bartender +wagon +baby elephant +accordion +horseshoe +building site +clutch +harvest +savanna +geranium +business woman +paddock +patch +beech tree +war +suburbs +hospital bed +motorcycle racer +moss +gravel +government agency +dollar bill +father +fjord +concert +nut +wedding photography +finish line +home plate +food +nose +thumb +village +dining room table +bumper +monster +blackberry +lime +conflict +gala +wallet +wrist +hug +mermaid +lava +lawyer +folk rock artist +arena +onion +toothbrush +fashion +perfume +flip +triangle +woodland +mail +grasshopper +studio +wood floor +den +racquet +cello +lemur +astronaut +glass table +blood +dvd +planter +silver +leash +master bedroom +forest +batter +shoe +engraving +opening +product +toe +cocktail +mallard duck +bike ride +oasis +wedding ring +cinematographer +holly +autograph +fence +ice cube +cove +pineapple +aurora +glass bead +produce +apartment building +cob +miniature +cockpit +flashlight +frog +sheep +groom +steel +watermelon +clip art +paper plate +ostrich +contour +mural +cub +paisley bandanna +winery +turn +handle +satellite +post +pork +child +asphalt +grocery store +vulture +trolley +nightclub +brick +trailer +compass +cereal +cafe +cartoon character +sugar +fiction book +glass floor +umpire +guitar +hamster +protester +airplane +garment +blazer +railway line +wedding +shoe box +parking lot +construction +graduation ceremony +tram +telescope +copper +pain +autumn forest +guest house +partner +crayon +dip +boot +corridor +computer keyboard +hockey player +chicken coop +bus station +gathering +ankle +bunk bed +wood table +football coach +monarch +pharmacy +legging +mannequin +female +train track +stack +canopy +design element +grandmother +symbol +beach hut +zucchini +bomb +businessman +skyscraper +tongue +case +sparkle +highland +ballroom +prom +estate +customer +archipelago +cheese +debate +carriage +bulldozer +pumpkin +sitting room +gas station +wedding reception +camp +dog bed +tower +property +river bed +pop latin artist +fridge +wine glass +coast +beer +tow truck +fire truck +mountain bike +thigh +heron +boat ride +gondola +turquoise +lake +llama +kitty +tin +waiting room +coffee cup +socialite +guard +tap +waterway +forehead +list +erosion +box +sea lion +pollen +dam +wasp +salon +tennis tournament +flower box +aquarium +rain cloud +clothing store +lead singer +cupcake +tortoise +lettering +sport facility +dance +dog house +nature +football +rooster +footballer +railway track +crowd +fishing rod +silhouette +wind turbine +sari +bus window +cloud +charity +medal +yoga +event +veil +fashion menswear milan week +news +knife +print +screen tv +walnut +fungus +ice cream +computer mouse +play +tribe +picture +video game +business card +music festival +rack +envelope +shower +dirt road +mine +oyster +monarch butterfly +dude +fruit salad +podium +fork +lace +test match +boulder +cricket player +staircase +peninsula +shopping +popcorn +oak +market stall +pine tree +mountaineer +student +closet +hood +handstand +centerpiece +insect +patient +makeover +tennis player +sheet +park bench +apple +organism +hook +turkey +tangerine +sibling +shopping mall +bird +scarf +smoothie +net +grass +napkin +ray +eyebrow +laptop keyboard +motorbike +woman hand +oven +book cover +easter egg +microwave +sand +snapshot +soccer ball +makeup +knight +bowling ball +shower curtain +flame +lightning +running +power plant +crib +cartoon +moat +fashion girl +wedding invitation +bottle +cliff +monastery +file photo +apartment +casino +cream +sweatshirt +storm +cruise +teddy bear +shovel +wind farm +writer +dock +professional +hotel room +job +monitor +donkey +pass +interview +duchess +mark +plank +beard +zombie +trio +channel +cricket team +windmill +vest +diagram +cable +winter scene +golden gate bridge +buffalo +studio portrait +pagoda +whiskey +freight train +kite +future +steam train +phone box +headset +wood +snowboarder +paper bag +slide +grapefruit +seating +morning +bronze sculpture +theatre actor +stump +jean +landmark +jam +waist +watercolor +hammock +light fixture +ice +basin +beverage +shelter +premiere +mound +ear +bronze +sunlight +street +energy +barn door +hike +fleet +claw +beach +pepperoni +bin +trainer +buffet +archive +toddler +referee +bay window +dove +production company +evening light +gate +farm +reed +fruit stand +explorer +snow storm +throw pillow +button +display case +bookcase +lead +lipstick +basketball court +cargo +ensemble +pope +clock tower +teen +speaker +rat +laptop +ski +mess +stadium +ferry boat +bunny +waterfront +downtown +sink +press conference +dinner +condiment +thread +audience +grid +car +plastic +people +barbecue +pigeon +urinal +seagull +volunteer +hockey +fir tree +pollution +trial +collar +area +meeting room +circus +yogurt +orangutan +viaduct +comedian +drone +scissor +pop rock artist +biscuit +panda +water feature +air balloon +remote control +watercolor painting +show +walk +post office +bike path +rap gangsta artist +microphone +crack +sunset sky +glass +tv show +cartoon style +stripe +foyer +signal +calligraphy +bulb +gardener +coffee bean +spider +tapestry +city skyline +necklace +kitten +traveler +veteran +frosting +fry +tennis court +tank top +butterfly house +mist +drummer +water level +scale +baseball glove +music video performer +champagne +camping +clothing +water drop +telephone box +pen +morning mist +fire engine +porch +opening ceremony +style +palm tree +fashion show +universe +scratch +axe +ottoman +explosion +rib +boutique +game +cucumber +fruit +stone bridge +nature reserve +track +train window +punch +telephone pole +velvet +sauce +moon +contrast +flamingo +bat +vending machine +ship +equestrian +shade +comforter +pallet +sparrow +wii +glaze +grocery +steeple +soccer player +contract +advertising +runner +chimpanzee +world +seat +project +chihuahua +bubble +willow +pedestal +soul hip hop artist +curb +drawer +leaf +banner +launch party +coach +government +snowball +toy +portrait +doctor +whiteboard +electronic +tiger +graffiti +column +nightstand +whistle +maxi dress +bench +wetsuit +bird feeder +football game +basketball +class +bathroom door +store window +text message +wreath +street view +binocular +pet +facade +drought +lemon +new year +night view +airplane window +specie +rule +jaw +wheat field +diet +pop artist +habitat +screenshot +scoreboard +shore +mane +quilt +ski lift +orchid +turban +christmas +airport +marina +glass door +glass bottle +restaurant +conductor +logo +sleep +tape +tomato +river bank +lilac +tooth +training +pottery +shop +steam engine +mason jar +base +procession +border +shoot +footprint +hotdog +bull +stocking +recreation +automobile model +design +country pop artist +river +retriever +department store +auditorium +sport car +supermarket +belt +cricket +window box +dress shirt +letter +residence +megaphone +pant +wildfire +bird nest +crab +swimsuit +candle +funeral +mill +national park +plant +cop +power line +perch +blue +finger +ferris wheel +globe +skateboard +helmet +movie theater +uniform +hammer +material +kid +well +butterfly +sideline +fashion fall show +planet earth +lift +male +sauna +gray +flour +sand sculpture +program +cabinet +infant +wheel +aircraft model +dough +garlic +skate +arrow +wrapping paper +ripple +lamp +iron +banknote +beaver +ferry +courtyard +bassist +countryside +steak +comfort +boxer +laundry room +campsite +brick building +golf +subway +headphone +fort +handbag +drum +flood +saddle +bass +labyrinth +needle +sun ray +app +menu +president +cardigan +dandelion +wetland +ice hockey player +number +city hall +fishing +portrait session +pug +key +art print +minister +hurdle +emergency +painting artist +flag pole +evening +purse +recipe +golf ball +coloring book +mountain peak +senior +holiday +bud +cousin +pantry +lap +skin +flag +tissue paper +ridge +wire fence +surfer +climber +photograph +sewing machine +cooler +actress +apple tree +cancer +starfish +automobile make +dumbbell +brace +tunnel +window +paint artist +composition +school student +condo +convertible +cushion +selfie +territory +guide +tree +court +shrimp +stone house +dress +eyelash +juice +broccoli +chain +tourism +mountain top +concept car +film premiere +light bulb +cafeteria +badge +flower bed +theater +root +racecar driver +basketball boy game +glove +skyline +wall +glacier +airport terminal +bug +trim +railway station +briefcase +flat +fountain +person +lane +asparagus +art +lantern +dishwasher +director +snake +lecture +game controller +tree branch +pub +bathing suit +queue +belly +poppy +bow +pitcher +ice cream cone +cave +candy +road bridge +host +traffic jam +earring +file +foot +watermark overlay stamp +mailbox +supercar +railing +bedroom +seafood +waffle +bronze statue +plan +flow +marble +basketball game +automobile +scene +cypress tree +soldier +skateboarder +glass building +cherry tree +pump +grain +wildebeest +loop +frame +bathtub +saxophone +diver +stalk +lily +bead +alley +flock +family room +manufacturing +pointer +worker +navy +potato +teacher +photography +dolly +boardwalk +water fountain +athlete +side dish +bay +ice hockey +phone +hero +face +gold medal +blind +swamp +researcher +swim +meatball +iguana +leather jacket +jellyfish +site +smoke +traffic signal +melon +beetle +calculator +skirt +plantation +sculptor +barrier +catcher +security guard +sketch +awning +steering wheel +mountain view +bus stop +pool +leg +spotlight +apron +mineral +inlet +sleeve +torch +emotion +march +police officer +performance +lamp post +fishing boat +summer +presentation +saucer +suitcase +supermodel +goalkeeper +shrub +rock artist +document +beach house +man +blue artist +cigar +railroad track +gown +mosaic +bungalow +alphabet +baseball field +shed +pedestrian +rail +soap +kitchen counter +dessert +dunk +blossom +conversation +fruit market +glass jar +military +beer bottle +photographer +tennis racket +competition +escalator +bell tower +stilt +ballerina +television +feather +fence post +rear +dahlia +red carpet +tub +hole +fortress +pack +telephone +cardboard +city park +platform +college student +arch bridge +wind +blender +bloom +ice rink +birthday +raven +fairy +embankment +hall +flower shop +suburb +barrel +biker +steam +dragonfly +formation +electricity +business people +symmetry +walkway +fisherman +gas mask +loch +youth +hanger +dot +fish +street market +animation film +crime fiction film +boar +emblem +halloween costume +kangaroo +couple +spoon +squirrel +neon sign +sky +office desk +beauty salon +breakwater +fashion look +toaster +author +news conference +outdoor +canoe +dragon +tool +shopping centre +ladybug +swimming pool +landscaping +ski pole +red +truck +fly +temple +level +sunday +railroad bridge +car mirror +lawn mower +flute +aircraft carrier +fashion menswear london week +sunshine +tile floor +skull +fossil +flower arrangement +diaper +sea turtle +cherry blossom +fireman +shack +lens +waiter +animal +basement +snow +autumn park +glass box +kick +head +anniversary +vine +back +paper lantern +fish tank +cellphone +silk +coral +notebook +photo +gazebo +ketchup +driver +farmer +bonfire +chestnut +photoshoot +football field +olive tree +pheasant +sandal +toilet +fireplace +music +deity +fish market +fig +bell +neck +grave +villa +cyclist +crate +grey +asphalt road +soccer +hostel +municipality +courthouse +roof +end table +pot +sedan +structure +folk artist +sport +sport team +protest +syringe +fashion designer +jersey +heart shape +kayak +stare +sit with +direct +read +photograph +spin +teach +laugh +carve +grow on +warm +watch +stretch +smell +decorate +shine +light +dance +send +park +chase +collect +lead +kiss +lead to +lick +smile +cheer +sit +point +block +rock +drop +cut +ski +wrap +lose +serve +provide +sleep +dress +embrace +burn +pack +stir +create +touch +wash +stick +reveal +shop +train +paint +groom +hunt +bloom +play +pay +brush +shoot +hold +picture +carry +sip +contain +turn +pour +pitch +give +add +blow +look in +show +walk +illuminate +kneel +cover +drag +post +present +fit +operate +fish +race +write +deliver +peel +push +run +sit around +buy +jump +walk on +attend +clean +sell +ride on +mount +host +dry +plant +sing +row +shake +perch +ride +fight +skateboard +live +call +surround +practice +play on +work on +step +relax +hit +fall in +flow +greet +launch +wear +hang on +drive +sit in +break +learn +fly +connect +display +locate +compete +go for +sail +lift +toast +help +run on +reflect +pose +scratch +frame +dribble +herd +enter +exit +place +inspect +build +pick +fill +grind +skate +offer +float +sit by +stand +release +rest +singe +climb +tie +mark +lay +stand around +capture +set +land +swinge +run in +kick +lean +head +sign +approach +swim +close +crash +control +fall +remove +repair +open +appear +travel +load +miss +check +surf +moor +smoke +drink +board +seat +feed +rise +sit on +swing +grow +strike +date +slide +share +graze +jump in +lie +extrude +roll +move +gather +eat +pull +run through +squeeze +lay on +draw +play with +wave +assemble +perform +march +score +attach +adjust +hang +hug +sleep on +throw +live in +talk +pet +work +run with +see +flip +catch +cook +receive +celebrate +look +classic +bridal +indoor +industrial +teenage +mini +grassy +aged +long +warm +light +handsome +happy +three +pregnant +circular +urban +silver +ceramic +3d +green +blonde +golden +dark +tropical +ripe +deep +fat +musical +giant +medical +medieval +bare +stunning +bold +geographical +huge +plastic +foggy +stormy +gothic +biological +empty +clear +antique +pink +steep +brown +striped +aerial +rainy +cool +flying +commercial +purple +trendy +blank +haired +dead +wooden +flat +high +beige +panoramic +angry +dozen +rural +solar +big +small +stained +thick +many +fresh +clean +strong +abstract +crowded +retro +dry +gorgeous +martial +modern +blue +cloudy +low +four +outdoor +single +much +beautiful +snowy +pretty +new +short +sunny +closed +rocky +red +two +double +male +gray +five +colorful +automotive +various +one +old +rusty +tall +wild +narrow +natural +several +frozen +textured +lush +young +hot +mixed +white +float +quiet +round +bright +religious +female +historical +shiny +traditional +tourist +yellow +bald +coastal +lovely +little +broken +romantic +wide +royal +rich +open +cute +ancient +cold +political +elderly +gold +full +rustic +metallic +floral +sad +wet +fancy +senior +tiny +stylish +large +frosty +orange +transparent +electronic +shallow +scared +armed +dirty +historic +black +few +windy +some +square +ornamental +sandy +thin \ No newline at end of file diff --git a/ppcls/utils/ram/tag_list.txt b/ppcls/utils/ram/tag_list.txt new file mode 100644 index 0000000000..11a61b68fb --- /dev/null +++ b/ppcls/utils/ram/tag_list.txt @@ -0,0 +1,3429 @@ +tennis +bear cub +observatory +bicycle +hillside +judge +watercolor illustration +granite +lobster +livery +stone +ceramic +ranch +cloth +smile +building +tattoo +cricketer +cheek +pear +source +winter +surface +spray +ceremony +magic +curve +container +fair +medicine +baby +tennis racquet +ornament +bamboo +duckling +song +safari +team presentation +daffodil +cross +toothpaste +shield +fashion model +capsule +map +creek +glass house +glass plate +siding +corner +water buffalo +bison +figure skater +diploma +tire +race +cable car +brain +gas stove +soap bubble +palette +snowboard +school child +trench coat +monk +fiber +kitchen window +sunglass +coffee +security +strawberry +penguin +tree root +loaf +engagement ring +lamb +vector cartoon illustration +sandwich +mountain village +shape +charm +fiction +knot +greenhouse +sushi +text +disaster +trophy +gang +strap +soccer game +cardinal +tee +turtle +water surface +grassland +dolphin +store +dirt +iceberg +pergola +farmer market +publicity portrait +tote bag +teenage girl +view mirror +session +commuter +dressing room +tricycle +christmas ball +headlight +police +armchair +chart +yacht +saw +printer +rock band +gingerbread house +tag +table lamp +hockey game +slope +font +wicker basket +jewelry +quarter +software +weapon +pin +worship +painter +goal +morning light +bike +baseball bat +elevator +cuisine +sausage +stunt +wrestler +statue +landing +pillar +willow tree +sea wave +chicken +peanut +muscle +bob +tv genre +bathroom window +radish +textile +pelican +marketplace +crest +elevation map +gift +parish +traffic light +campfire +fog +award winner +beach ball +mat +white house +plaster +moped +football team +solution +bicyclist +bit +playground +darkness +cake +maple leave +mold +cracker +blueberry +rubble +container ship +pedestrian bridge +snail +parrot +form +circuit +highlight +pickup truck +koala +rain +system +weather +raincoat +soccer team +windshield +thunderstorm +mike +bird house +bridge +grandfather +restroom +animation +wilderness +clown +banana +brown +braid +dining room +kindergarten +launch event +purple +school +stairwell +brooch +movie poster image +mountain river +shelf +wicket +headboard +buddha +flower field +dugout +cd +bald eagle +lagoon +seaweed +agriculture +emergency service +maple tree +parachute +continent +amusement park +remote +bun +tackle +hospital +garage door +birthday party +friendship +go +mausoleum +jeep +raccoon +step +ice hockey team +cigarette +lace dress +forest floor +mall +captain +milk +golf course +meal +picnic table +sail +volleyball +canal +terrace +computer desk +caravan +hotel +cheerleader +nurse +museum +marsh +fox +plateau +night +twin +letter logo +autumn tree +powder +convention +creature +lighthouse +shop window +jacket +stork +taxi +trade +blackboard +olive +road sign +resort +snowflake +cemetery +travel +evening dress +picnic +drink +winter morning +football player +snack +boxing glove +dinner party +airline +swing +port +wheelbarrow +bathroom sink +sweater +ambulance +gear +oil +wii controller +array +home office +car show +mixture +profession +tree frog +square +facility +coral reef +sea wall +pizza +exhibit +demolition +trout +ring +coffee shop +bracelet +bean +lip +fencing +landscape +sitting +package +metal +bust +king +hair +window seat +wildlife +trunk +greenery +stencil +fire hydrant +bridesmaid +plaza +alps +tower bridge +crop top +crossing +cinema +pedestrian crossing +family +shopping cart +stomach +church building +screen door +skater +soccer field +kettle +mussel +raindrop +candy cane +water lily +flower girl +desert +enclosure +christmas light +kitchen +caterpillar +plaid +bath +bush +mud +ballet +knee +adult +raft +sea view +cactus +office chair +overall +rim +scaffolding +pig +cover +poster page +sprinkle +chandelier +algae +traffic +surfboard +book +filming +flash +mansion +camouflage +trouser +ticket +weed +cab +trench +elephant +huddle +sphere +christmas decoration +city +launch +doll +christmas ornament +fabric +bikini +biplane +breakfast +neighbourhood +race track +foliage +avocado +school bus +footwear +highway +ocean view +art vector illustration +wall clock +curtain +teenager +kitchen area +robot +tusk +lounge chair +beam +paddle +camel +lid +world map +city view +newlywed +cargo ship +yellow +exhibition +bend +novel +wool +ontario +bread +campus +coastline +cutting board +booth +table top +carpet +beach chair +workout +street food +fun +costumer film designer +gadget +artist +fishing village +builder +violinist +iphone +spider web +traffic sign +ruin +rescue +clipboard +seal +film director +paw +nursery +intersection +tomato sauce +taste +paddy field +christmas tree +wave +stool +watering can +rug +daytime +subway station +craft +pine forest +black +planet +motif +christmas market +glass window +college +wheat +damage +rectangle +picture frame +chess +guest room +street corner +religion +seed +puzzle +freeway +beauty +ocean +watch +mother +garage +quote +dj +supporter +hip hop artist +muffin +eiffel tower +cash +firefighter +cauliflower +bunker +sled +manicure +shark +stall +jungle +family home +tour bus +chimney +touchdown +roundabout +coyote +street scene +tank +wedding dress +mantle +bedroom window +coconut +chapel +goat +living space +rock wall +polka dot +railway +mandala +mango +lesson +mountain landscape +team photo +bookshelf +meter +bulldog +evening sun +stick +card +pink +fish pond +paint +pill +cart +pea +van +album +football college game +mountain pass +doughnut +ski slope +match +official +shadow +organ +celebration +coin +log cabin +firework display +present +twig +chef +confetti +footpath +tour +ponytail +artwork +race car +club +season +hose +pencil +aircraft +rock formation +wardrobe +participant +politician +engineer +peace +filter +sailing boat +water bottle +service dog +poodle +loki +statesman +sleeping bag +outskirt +clock +factory +oak tree +physician +color +room +stairway +company +lady +graph +faucet +tablecloth +subway train +chocolate chip cookie +headquarters +screw +goggle +halloween +city street +swirl +cord +forward +bone +bedding +archway +wig +lobby +mask +attic +kitchen table +skylight +fire +exit +oil painting +passenger +meditation +salmon +fedora +rubber stamp +orange juice +arch +scientist +stroll +manhattan +float +baseball uniform +circle +church +decker bus +competitor +zoo +basketball team +tourist +daughter +silverware +ceiling fan +birth +vase +jack +mushroom +spiral +cage +limb +salad +ad +control +earth +party +bolt +tractor +barley +wedding photo +hawk +warehouse +vegetable garden +chocolate cake +cabbage +floor window +baby shower +magnifying glass +table +stethoscope +reading +mission +croissant +gift box +rocket +forest road +cooking +suite +hill country +motorcycle +baseball player +angle +drug +sport association +championship +family portrait +florist +softball +egret +office +plywood +jockey +mosque +brunch +beanie +office building +pattern +calendar +indoor +pepper +ledge +trail +fuel +laptop computer +tennis shoe +deck chair +guitarist +barn +surgery +cartoon illustration +nebula +railroad +mountain goat +goose +car door +cheer +liquid +hardwood floor +pathway +acorn +gull +airliner +couch +lake house +spaghetti +promenade +collection +garden +bank +robin +tennis ball +peony +gymnast +lavender +deck +test +riverside +rapper +domino +bride +mouse +basil +wedding couple +ocean wave +arm +kitchen floor +grove +family member +backyard +raspberry +forest fire +officer +hibiscus +canyon +composer +signature +olive oil +hibiscus flower +rose +vector icon +sunrise +horseback +motor scooter +office worker +tradition +ingredient +washing machine +lighting +bagel +sailboat +policeman +mare +graphic +halloween pumpkin +stock +pilot +education +team +body +horse +kimono +bazaar +bag +recording studio +parsley +entrance +denim +vet +horse farm +charcoal +architecture +glass vase +puppy +estuary +television show host +city bus +shoulder +beast +balance +golfer +roadside +denim jacket +stone wall +counter top +app icon +toast +head coach +ham +warrior +gem +refrigerator +snowman +construction worker +coal +website +morning fog +mustard +human +owl +puppy dog +piggy bank +vegetation +pirate +action film +marshmallow +thanksgiving +business +disease +signage +greeting +skate park +tile +mouth +spinach +vacation +leader +shrine +walker +science fiction film +bill +rabbit +motor boat +bar +radio +barge +tail +chainsaw +gallery +rainbow +pasta +padlock +web +pastry +ink +reef +school uniform +shawl +treasure +peach +dinner table +injury +harbor +witch +car dealership +litter +gesture +documentary +marriage +sea shell +priest +dome +kit +icon +seaside +bucket +entertainment +stable +hat +puddle +sock +shopper +technology +harbour +orbit +antler +tube +flag waving +cook +tight +commander +farmland +switch +hiker +wedding ceremony +award ceremony +champion +chopstick +farmhouse +performer +spike +accident +cruise ship +passenger train +attraction +entertainer +rear view +sidewalk +parade +racing +plane +ritual +peacock +pocket +plum +drop +carrot +floor +sunset +troop +architect +coffee table +dust +outline +leather +charity event +heat +whale +laundry +coconut tree +crosswalk +pony +ant +pipe +string +coat +angel +beef +church tower +dish +pitch +cupboard +thermometer +dirt field +fireworks +minute +cane +pajama +flower garden +autumn +trash can +dachshund +banana tree +tray +moose +roadway +carnival +antenna +pole +castle wall +ram +cattle +hay +cookie +swimmer +baseball team +strait +hedge +jet +fire pit +octopus +calf +cube +opera +cardboard box +tiara +kitchen sink +prairie +bowl +galaxy +straw hat +linen +ski resort +stitch +street lamp +motorist +icicle +stain +flora +drain +kitchen cabinet +decor +bouquet +pound +interior design +nail polish +figurine +tomb +disc +twist +blouse +ribbon +figure +burger +cork +soccer goalkeeper +train bridge +drinking water +dew +baker +storm cloud +tarmac +tv drama +sponge +magnet +sailor +entry +swan +exercise +sloth +jewel +scuba diver +bite +cat tree +tent +can +tennis match +ecosystem +picket fence +palm +train car +frying pan +rally +tablet pc +reindeer +image +wolf +chin +conservatory +flood water +cityscape +beach sand +car park +pavement +farm field +swimming +winter storm +stem +pillow +inning +gorilla +desk +avenue +fern +money +pearl +train station +skillet +nap +barber +library +freezer +label +rainforest +parking sign +mirror +wing +noodle +press room +sculpture +tablet +viewer +prayer +mini +mechanic +laugh +rice field +hand +mustache +mountain road +catwalk +conference +cape +installation +musician +stream +machine +speech +crocodile +soccer match +town square +passport +post box +point +stone building +motorway +mix +dentist +businessperson +happiness +boat +vineyard +treadmill +glass wall +water droplet +coffee mug +graduate +sunflower +parliament +shepherd +movie +wine +orchard +tulip +motherboard +cup +broom +spot +drawing +polo shirt +graduation +film producer +moonlight +glow +film format +t shirt +rock face +sword +clinic +festival day +meadow +staple +pupil +training ground +rider +flower +foal +wharf +foot bridge +shooting +top +mast +police car +robe +wedding bouquet +stop sign +birthday cake +glitter +butter +scooter +tundra +superhero +pocket watch +inscription +youngster +fruit tree +movie poster +engine +foundation +motorcyclist +take +woman +antelope +country artist +road trip +typewriter +tuxedo +brand +pine +bathroom +paradise +texture +balloon +dining table +home +computer screen +actor +clip +tv tower +panorama +summit +cat +plot +eagle +dancer +pup +studio shot +tear +bird bath +classroom +bookstore +city wall +tv programme +blade +easel +buttercream +sweet +designer +diamond +handshake +herb +corn field +seafront +concrete +street artist +gas +stamp +window display +paper +note +pint +quarry +research +fixture +manager +soil +leopard +board game +ladder +stop light +island +ramp +football match +icing +drill +currency +summer evening +topping +pyramid +pomegranate +cell +ivy +squad +scenery +computer +locomotive +surf +mascot +dune +path +duck +twilight +wire +bow tie +strike +cormorant +car wash +crane +market +philosopher +alarm clock +camera +birch +greeting card +plain +clay +donut +lock +moth +laboratory +fan +violin +jazz fusion artist +mountain biker +terrain +magazine +pickup +comedy film +smartphone +film +bed +microwave oven +tournament +lawn +car window +alligator +screen +jetty +shopping bag +landscape view +cabinetry +friendly match +thing +petal +shopping center +transport +ballet dancer +shoreline +princess +car seat +parking meter +green +vodka +band +rock +costume +warning sign +strip +plaque +wheelchair +headband +ginger +dice +media +hairdresser +press +living room +stove +player +cherry +workshop +carving +embroidery +doodle +adventure +rugby player +monument +brush +marker +loft +postcard +collage +ball +professor +dresser +gig +festival +blackbird +makeup artist +video camera +sticker +peak +wildflower +santa hat +rodeo +wedding photographer +guy +staff +waterfall +operation +defender +falcon +haze +individual +gentleman +greyhound +rocking chair +rice +garbage +platter +chocolate +splash +business suit +cheetah +valley +maze +trampoline +garland +slalom +unicorn +tree stump +painting +romance +fight +alcohol +ghost +fondant +spa +shutter +death +demonstration +cotton +pier +flea market +history +savannah +fist +aisle +crew +jug +pose +anchor +teapot +boat house +business team +tripod +bee +pebble +mattress +canvas +hallway +campaign +pod +lake district +article +white +sofa +honey +marathon +pancake +tourist attraction +wedding gown +battle +shelving +sea +sheet music +pie +yarn +construction site +flyer +tie +star +lettuce +martial artist +dart +straw +reflection +conference room +temperature +rugby +mosquito +physicist +rock climber +crash +backdrop +toilet seat +sand castle +water park +toy car +waste +luxury +hangar +rv +tree trunk +board +gold +project picture +cap +cottage +relief +attire +microscope +battery +roll +line +parking garage +crystal +broadcasting +brick wall +lab +flooring +meeting +3d cg rendering +desktop computer +cowboy +sailing ship +junction +hairstyle +homework +profile +model +flower pot +street light +salt lake +maple +space +blizzard +throw +zebras +brochure +constellation +beak +kilt +pond +blue sky +sneaker +sand dune +morning sun +almond +grill +curl +basketball girl game +chameleon +toilet bowl +prince +keyboard +queen +computer monitor +writing +crown +basilica +kiss +house +parking +football competition +shell +sport equipment +comedy +baboon +vendor +rise building +wrap +food truck +cat bed +rickshaw +flare +teal +nectar +eclipse +vehicle +steam locomotive +gorge +cow +christmas card +demonstrator +memorial +towel +jewellery +train +frisbee +baseball game +fur +afternoon sun +community +sparkler +bandage +firework +dollar +pasture +video +bus +tree house +seashore +field +hamburger +souvenir +hedgehog +worm +pine cone +osprey +dinosaur +vegetable +junk +poster +army +winger +bundle +stage +growth +wedding party +service +blanket +ruler +eye +credit card +castle +diner +hut +elk +hard rock artist +nun +dog breed +nest +drama film +number icon +water tank +giraffe +altar +pavilion +tv personality +suv +street vendor +street sign +ditch +debris +foam +takeoff +spice +mountain lake +tea +orchestra +spacecraft +counter +abbey +mountain +hydrangea +racer +orange tree +tide +cowboy hat +rapid +town +wild +herd +vein +driveway +jar +bark +illustration +horror film +corn +stroller +industry +mountain stream +gym +neckline +pan +client +spectator +eggplant +camper +fawn +hoodie +meat +lemonade +food market +slum +comic book character +flower market +love +palace +gun +heel +shopping street +shooting basketball guard +family photo +rooftop +laundry basket +airport runway +horn +face mask +flight +appetizer +violet +country lane +cement +instrument +tv actor +spark +celebrity +award +country house +standing +auction +date +engagement +puck +advertisement +chair +zebra +driftwood +bumblebee +maple leaf +bonnet +orange +water tower +door +singer +floor plan +discussion +theatre +pilgrim +mug +branch +window sill +baseball pitcher +bakery +lollipop +basketball player +toilet paper +chalkboard +cabin +sign +night sky +cannon +fishing net +submarine +suit +fur coat +wine bottle +folder +street art +suspension bridge +evening sky +billboard +postage stamp +newspaper +transportation +surgeon +light +park +horizon +road +sand bar +trumpet +lounge +cloud forest +birthday celebration +balcony +anime +beehive +umbrella +goldfish +baseball cap +waterhole +ceiling +carousel +backpack +plant pot +atmosphere +sunflower field +spire +vision +woodpecker +chip +pool table +lotus flower +cone +humpback whale +reservoir +hunt +piano +plate +dining area +luggage +skier +dance floor +crow +stair +overpass +opera house +bear +jazz artist +water +vessel +cast +yard +cathedral +basketball hoop +graveyard +sound +berry +onlooker +fauna +birch tree +retail +hill +skeleton +journalist +frost +basket +nail +dusk +trash +dawn +clover +hen +volcano +basketball coach +home decor +charge +haircut +sense +university +lizard +daisy +tablet computer +grass field +prison +metal artist +bathroom mirror +window frame +chest +flavor +pop country artist +market square +monkey +blog +deer +speech bubble +dog +independence day +girl +boy +tartan +furniture +appliance +office window +fish boat +sand box +tv sitcom +drama +sleigh +depression +paper towel +baseball +protestor +grape +wedding cake +invitation +accessory +pick +grandparent +racket +tea plantation +outdoors +egg +glass bowl +sun +organization +lion +panel +station +wallpaper +helicopter +salt +vanity +patio +lunch +street performer +mountain range +soup +bacon +power station +cantilever bridge +hummingbird +shirt +rope +hip +chalk +pendant +choir +tv +lichen +railway bridge +art gallery +bartender +wagon +baby elephant +accordion +horseshoe +building site +clutch +harvest +savanna +geranium +business woman +paddock +patch +beech tree +war +suburbs +hospital bed +motorcycle racer +moss +gravel +government agency +dollar bill +father +fjord +concert +nut +wedding photography +finish line +home plate +food +nose +thumb +village +dining room table +bumper +monster +blackberry +lime +conflict +gala +wallet +wrist +hug +mermaid +lava +lawyer +folk rock artist +arena +onion +toothbrush +fashion +perfume +flip +triangle +woodland +mail +grasshopper +studio +wood floor +den +racquet +cello +lemur +astronaut +glass table +blood +dvd +planter +silver +leash +master bedroom +forest +batter +shoe +engraving +opening +product +toe +cocktail +mallard duck +bike ride +oasis +wedding ring +cinematographer +holly +autograph +fence +ice cube +cove +pineapple +aurora +glass bead +produce +apartment building +cob +miniature +cockpit +flashlight +frog +sheep +groom +steel +watermelon +clip art +paper plate +ostrich +contour +mural +cub +paisley bandanna +winery +turn +handle +satellite +post +pork +child +asphalt +grocery store +vulture +trolley +nightclub +brick +trailer +compass +cereal +cafe +cartoon character +sugar +fiction book +glass floor +umpire +guitar +hamster +protester +airplane +garment +blazer +railway line +wedding +shoe box +parking lot +construction +graduation ceremony +tram +telescope +copper +pain +autumn forest +guest house +partner +crayon +dip +boot +corridor +computer keyboard +hockey player +chicken coop +bus station +gathering +ankle +bunk bed +wood table +football coach +monarch +pharmacy +legging +mannequin +female +train track +stack +canopy +design element +grandmother +symbol +beach hut +zucchini +bomb +businessman +skyscraper +tongue +case +sparkle +highland +ballroom +prom +estate +customer +archipelago +cheese +debate +carriage +bulldozer +pumpkin +sitting room +gas station +wedding reception +camp +dog bed +tower +property +river bed +pop latin artist +fridge +wine glass +coast +beer +tow truck +fire truck +mountain bike +thigh +heron +boat ride +gondola +turquoise +lake +llama +kitty +tin +waiting room +coffee cup +socialite +guard +tap +waterway +forehead +list +erosion +box +sea lion +pollen +dam +wasp +salon +tennis tournament +flower box +aquarium +rain cloud +clothing store +lead singer +cupcake +tortoise +lettering +sport facility +dance +dog house +nature +football +rooster +footballer +railway track +crowd +fishing rod +silhouette +wind turbine +sari +bus window +cloud +charity +medal +yoga +event +veil +fashion menswear milan week +news +knife +print +screen tv +walnut +fungus +ice cream +computer mouse +play +tribe +picture +video game +business card +music festival +rack +envelope +shower +dirt road +mine +oyster +monarch butterfly +dude +fruit salad +podium +fork +lace +test match +boulder +cricket player +staircase +peninsula +shopping +popcorn +oak +market stall +pine tree +mountaineer +student +closet +hood +handstand +centerpiece +insect +patient +makeover +tennis player +sheet +park bench +apple +organism +hook +turkey +tangerine +sibling +shopping mall +bird +scarf +smoothie +net +grass +napkin +ray +eyebrow +laptop keyboard +motorbike +woman hand +oven +book cover +easter egg +microwave +sand +snapshot +soccer ball +makeup +knight +bowling ball +shower curtain +flame +lightning +running +power plant +crib +cartoon +moat +fashion girl +wedding invitation +bottle +cliff +monastery +file photo +apartment +casino +cream +sweatshirt +storm +cruise +teddy bear +shovel +wind farm +writer +dock +professional +hotel room +job +monitor +donkey +pass +interview +duchess +mark +plank +beard +zombie +trio +channel +cricket team +windmill +vest +diagram +cable +winter scene +golden gate bridge +buffalo +studio portrait +pagoda +whiskey +freight train +kite +future +steam train +phone box +headset +wood +snowboarder +paper bag +slide +grapefruit +seating +morning +bronze sculpture +theatre actor +stump +jean +landmark +jam +waist +watercolor +hammock +light fixture +ice +basin +beverage +shelter +premiere +mound +ear +bronze +sunlight +street +energy +barn door +hike +fleet +claw +beach +pepperoni +bin +trainer +buffet +archive +toddler +referee +bay window +dove +production company +evening light +gate +farm +reed +fruit stand +explorer +snow storm +throw pillow +button +display case +bookcase +lead +lipstick +basketball court +cargo +ensemble +pope +clock tower +teen +speaker +rat +laptop +ski +mess +stadium +ferry boat +bunny +waterfront +downtown +sink +press conference +dinner +condiment +thread +audience +grid +car +plastic +people +barbecue +pigeon +urinal +seagull +volunteer +hockey +fir tree +pollution +trial +collar +area +meeting room +circus +yogurt +orangutan +viaduct +comedian +drone +scissor +pop rock artist +biscuit +panda +water feature +air balloon +remote control +watercolor painting +show +walk +post office +bike path +rap gangsta artist +microphone +crack +sunset sky +glass +tv show +cartoon style +stripe +foyer +signal +calligraphy +bulb +gardener +coffee bean +spider +tapestry +city skyline +necklace +kitten +traveler +veteran +frosting +fry +tennis court +tank top +butterfly house +mist +drummer +water level +scale +baseball glove +music video performer +champagne +camping +clothing +water drop +telephone box +pen +morning mist +fire engine +porch +opening ceremony +style +palm tree +fashion show +universe +scratch +axe +ottoman +explosion +rib +boutique +game +cucumber +fruit +stone bridge +nature reserve +track +train window +punch +telephone pole +velvet +sauce +moon +contrast +flamingo +bat +vending machine +ship +equestrian +shade +comforter +pallet +sparrow +wii +glaze +grocery +steeple +soccer player +contract +advertising +runner +chimpanzee +world +seat +project +chihuahua +bubble +willow +pedestal +soul hip hop artist +curb +drawer +leaf +banner +launch party +coach +government +snowball +toy +portrait +doctor +whiteboard +electronic +tiger +graffiti +column +nightstand +whistle +maxi dress +bench +wetsuit +bird feeder +football game +basketball +class +bathroom door +store window +text message +wreath +street view +binocular +pet +facade +drought +lemon +new year +night view +airplane window +specie +rule +jaw +wheat field +diet +pop artist +habitat +screenshot +scoreboard +shore +mane +quilt +ski lift +orchid +turban +christmas +airport +marina +glass door +glass bottle +restaurant +conductor +logo +sleep +tape +tomato +river bank +lilac +tooth +training +pottery +shop +steam engine +mason jar +base +procession +border +shoot +footprint +hotdog +bull +stocking +recreation +automobile model +design +country pop artist +river +retriever +department store +auditorium +sport car +supermarket +belt +cricket +window box +dress shirt +letter +residence +megaphone +pant +wildfire +bird nest +crab +swimsuit +candle +funeral +mill +national park +plant +cop +power line +perch +blue +finger +ferris wheel +globe +skateboard +helmet +movie theater +uniform +hammer +material +kid +well +butterfly +sideline +fashion fall show +planet earth +lift +male +sauna +gray +flour +sand sculpture +program +cabinet +infant +wheel +aircraft model +dough +garlic +skate +arrow +wrapping paper +ripple +lamp +iron +banknote +beaver +ferry +courtyard +bassist +countryside +steak +comfort +boxer +laundry room +campsite +brick building +golf +subway +headphone +fort +handbag +drum +flood +saddle +bass +labyrinth +needle +sun ray +app +menu +president +cardigan +dandelion +wetland +ice hockey player +number +city hall +fishing +portrait session +pug +key +art print +minister +hurdle +emergency +painting artist +flag pole +evening +purse +recipe +golf ball +coloring book +mountain peak +senior +holiday +bud +cousin +pantry +lap +skin +flag +tissue paper +ridge +wire fence +surfer +climber +photograph +sewing machine +cooler +actress +apple tree +cancer +starfish +automobile make +dumbbell +brace +tunnel +window +paint artist +composition +school student +condo +convertible +cushion +selfie +territory +guide +tree +court +shrimp +stone house +dress +eyelash +juice +broccoli +chain +tourism +mountain top +concept car +film premiere +light bulb +cafeteria +badge +flower bed +theater +root +racecar driver +basketball boy game +glove +skyline +wall +glacier +airport terminal +bug +trim +railway station +briefcase +flat +fountain +person +lane +asparagus +art +lantern +dishwasher +director +snake +lecture +game controller +tree branch +pub +bathing suit +queue +belly +poppy +bow +pitcher +ice cream cone +cave +candy +road bridge +host +traffic jam +earring +file +foot +watermark overlay stamp +mailbox +supercar +railing +bedroom +seafood +waffle +bronze statue +plan +flow +marble +basketball game +automobile +scene +cypress tree +soldier +skateboarder +glass building +cherry tree +pump +grain +wildebeest +loop +frame +bathtub +saxophone +diver +stalk +lily +bead +alley +flock +family room +manufacturing +pointer +worker +navy +potato +teacher +photography +dolly +boardwalk +water fountain +athlete +side dish +bay +ice hockey +phone +hero +face +gold medal +blind +swamp +researcher +swim +meatball +iguana +leather jacket +jellyfish +site +smoke +traffic signal +melon +beetle +calculator +skirt +plantation +sculptor +barrier +catcher +security guard +sketch +awning +steering wheel +mountain view +bus stop +pool +leg +spotlight +apron +mineral +inlet +sleeve +torch +emotion +march +police officer +performance +lamp post +fishing boat +summer +presentation +saucer +suitcase +supermodel +goalkeeper +shrub +rock artist +document +beach house +man +blue artist +cigar +railroad track +gown +mosaic +bungalow +alphabet +baseball field +shed +pedestrian +rail +soap +kitchen counter +dessert +dunk +blossom +conversation +fruit market +glass jar +military +beer bottle +photographer +tennis racket +competition +escalator +bell tower +stilt +ballerina +television +feather +fence post +rear +dahlia +red carpet +tub +hole +fortress +pack +telephone +cardboard +city park +platform +college student +arch bridge +wind +blender +bloom +ice rink +birthday +raven +fairy +embankment +hall +flower shop +suburb +barrel +biker +steam +dragonfly +formation +electricity +business people +symmetry +walkway +fisherman +gas mask +loch +youth +hanger +dot +fish +street market +animation film +crime fiction film +boar +emblem +halloween costume +kangaroo +couple +spoon +squirrel +neon sign +sky +office desk +beauty salon +breakwater +fashion look +toaster +author +news conference +outdoor +canoe +dragon +tool +shopping centre +ladybug +swimming pool +landscaping +ski pole +red +truck +fly +temple +level +sunday +railroad bridge +car mirror +lawn mower +flute +aircraft carrier +fashion menswear london week +sunshine +tile floor +skull +fossil +flower arrangement +diaper +sea turtle +cherry blossom +fireman +shack +lens +waiter +animal +basement +snow +autumn park +glass box +kick +head +anniversary +vine +back +paper lantern +fish tank +cellphone +silk +coral +notebook +photo +gazebo +ketchup +driver +farmer +bonfire +chestnut +photoshoot +football field +olive tree +pheasant +sandal +toilet +fireplace +music +deity +fish market +fig +bell +neck +grave +villa +cyclist +crate +grey +asphalt road +soccer +hostel +municipality +courthouse +roof +end table +pot +sedan +structure +folk artist +sport +sport team +protest +syringe +fashion designer +jersey +heart shape +kayak +stare +sit with +direct +read +photograph +spin +teach +laugh +carve +grow on +warm +watch +stretch +smell +decorate +shine +light +dance +send +park +chase +collect +lead +kiss +lead to +lick +smile +cheer +sit +point +block +rock +drop +cut +ski +wrap +lose +serve +provide +sleep +dress +embrace +burn +pack +stir +create +touch +wash +stick +reveal +shop +train +paint +groom +hunt +bloom +play +pay +brush +shoot +hold +picture +carry +sip +contain +turn +pour +pitch +give +add +blow +look in +show +walk +illuminate +kneel +cover +drag +post +present +fit +operate +fish +race +write +deliver +peel +push +run +sit around +buy +jump +walk on +attend +clean +sell +ride on +mount +host +dry +plant +sing +row +shake +perch +ride +fight +skateboard +live +call +surround +practice +play on +work on +step +relax +hit +fall in +flow +greet +launch +wear +hang on +drive +sit in +break +learn +fly +connect +display +locate +compete +go for +sail +lift +toast +help +run on +reflect +pose +scratch +frame +dribble +herd +enter +exit +place +inspect +build +pick +fill +grind +skate +offer +float +sit by +stand +release +rest +singe +climb +tie +mark +lay +stand around +capture +set +land +swinge +run in +kick +lean +head +sign +approach +swim +close +crash +control +fall +remove +repair +open +appear +travel +load +miss +check +surf +moor +smoke +drink +board +seat +feed +rise +sit on +swing +grow +strike +date +slide +share +graze +jump in +lie +extrude +roll +move +gather +eat +pull +run through +squeeze +lay on +draw +play with +wave +assemble +perform +march +score +attach +adjust +hang +hug +sleep on +throw +live in +talk +pet +work +run with +see +flip +catch +cook +receive +celebrate +look +classic +bridal +indoor +industrial +teenage +mini +grassy +aged +long +warm +light +handsome +happy +three +pregnant +circular +urban +silver +ceramic +3d +green +blonde +golden +dark +tropical +ripe +deep +fat +musical +giant +medical +medieval +bare +stunning +bold +geographical +huge +plastic +foggy +stormy +gothic +biological +empty +clear +antique +pink +steep +brown +striped +aerial +rainy +cool +flying +commercial +purple +trendy +blank +haired +dead +wooden +flat +high +beige +panoramic +angry +dozen +rural +solar +big +small +stained +thick +many +fresh +clean +strong +abstract +crowded +retro +dry +gorgeous +martial +modern +blue +cloudy +low +four +outdoor +single +much +beautiful +snowy +pretty +new +short +sunny +closed +rocky +red +two +double +male +gray +five +colorful +automotive +various +one +old +rusty +tall +wild +narrow +natural +several +frozen +textured +lush +young +hot +mixed +white +float +quiet +round +bright +religious +female +historical +shiny +traditional +tourist +yellow +bald +coastal +lovely +little +broken +romantic +wide +royal +rich +open +cute +ancient +cold +political +elderly +gold +full +rustic +metallic +floral +sad +wet +fancy +senior +tiny +stylish +large +frosty +orange +transparent +electronic +shallow +scared +armed +dirty +historic +black +few +windy +some +square +ornamental +sandy +thin \ No newline at end of file diff --git a/tools/__init__.py b/tools/__init__.py index cdefb80509..634230de88 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import infer \ No newline at end of file +from . import infer diff --git a/tools/export_model_multimodal.py b/tools/export_model_multimodal.py new file mode 100644 index 0000000000..a786fcf748 --- /dev/null +++ b/tools/export_model_multimodal.py @@ -0,0 +1,33 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) + +import paddle +import paddle.nn as nn + +from ppcls.utils import config +from ppcls.engine.engine_multimodal import EngineMultimodal + +if __name__ == "__main__": + args = config.parse_args() + config = config.get_config( + args.config, overrides=args.override, show=False) + if config["Arch"].get("use_sync_bn", False): + config["Arch"]["use_sync_bn"] = False + engine = EngineMultimodal(config, mode="export") + engine.export() diff --git a/tools/infer_multimodal.py b/tools/infer_multimodal.py new file mode 100644 index 0000000000..861a01b160 --- /dev/null +++ b/tools/infer_multimodal.py @@ -0,0 +1,29 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) + +from ppcls.utils import config +from ppcls.engine.engine_multimodal import EngineMultimodal + +if __name__ == "__main__": + args = config.parse_args() + config = config.get_config( + args.config, overrides=args.override, show=False) + engine = EngineMultimodal(config, mode="infer") + engine.infer() diff --git a/tools/train_multimodal.py b/tools/train_multimodal.py new file mode 100644 index 0000000000..c3a9818bb7 --- /dev/null +++ b/tools/train_multimodal.py @@ -0,0 +1,30 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) + +from ppcls.utils import config +from ppcls.engine.engine_multimodal import EngineMultimodal + +if __name__ == "__main__": + args = config.parse_args() + config = config.get_config( + args.config, overrides=args.override, show=False) + config.profiler_options = args.profiler_options + + engine = EngineMultimodal(config, mode="train") + engine.train()