From 22fe8c5e9e4a50219654bb347da43b36592ed475 Mon Sep 17 00:00:00 2001 From: Michael Spencer Jr Date: Mon, 16 Dec 2019 18:42:31 -0600 Subject: [PATCH] Initial commit of C# port - does not build --- ports/win32-csharp/src/App.config | 58 ++++++ ports/win32-csharp/src/Encoder.cs | 94 +++++++++ ports/win32-csharp/src/Form1.Designer.cs | 40 ++++ ports/win32-csharp/src/Form1.cs | 39 ++++ ports/win32-csharp/src/GPT2Generator.cs | 66 ++++++ ports/win32-csharp/src/GPT2Model.cs | 58 ++++++ ports/win32-csharp/src/POC.csproj | 190 ++++++++++++++++++ ports/win32-csharp/src/Program.cs | 22 ++ .../src/Properties/AssemblyInfo.cs | 36 ++++ .../src/Properties/Resources.Designer.cs | 71 +++++++ .../src/Properties/Resources.resx | 117 +++++++++++ .../src/Properties/Settings.Designer.cs | 30 +++ .../src/Properties/Settings.settings | 7 + ports/win32-csharp/src/Sample.cs | 47 +++++ ports/win32-csharp/src/packages.config | 23 +++ 15 files changed, 898 insertions(+) create mode 100644 ports/win32-csharp/src/App.config create mode 100644 ports/win32-csharp/src/Encoder.cs create mode 100644 ports/win32-csharp/src/Form1.Designer.cs create mode 100644 ports/win32-csharp/src/Form1.cs create mode 100644 ports/win32-csharp/src/GPT2Generator.cs create mode 100644 ports/win32-csharp/src/GPT2Model.cs create mode 100644 ports/win32-csharp/src/POC.csproj create mode 100644 ports/win32-csharp/src/Program.cs create mode 100644 ports/win32-csharp/src/Properties/AssemblyInfo.cs create mode 100644 ports/win32-csharp/src/Properties/Resources.Designer.cs create mode 100644 ports/win32-csharp/src/Properties/Resources.resx create mode 100644 ports/win32-csharp/src/Properties/Settings.Designer.cs create mode 100644 ports/win32-csharp/src/Properties/Settings.settings create mode 100644 ports/win32-csharp/src/Sample.cs create mode 100644 ports/win32-csharp/src/packages.config diff --git a/ports/win32-csharp/src/App.config b/ports/win32-csharp/src/App.config new file mode 100644 index 00000000..f78ae81c --- /dev/null +++ b/ports/win32-csharp/src/App.config @@ -0,0 +1,58 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/ports/win32-csharp/src/Encoder.cs b/ports/win32-csharp/src/Encoder.cs new file mode 100644 index 00000000..3957c009 --- /dev/null +++ b/ports/win32-csharp/src/Encoder.cs @@ -0,0 +1,94 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Newtonsoft.Json; +using Newtonsoft.Json.Serialization; +using Newtonsoft.Json.Utilities; +using System.IO; +using System.Text.RegularExpressions; + +namespace AIDungeon.net +{ + public class Encoder + { + public IDictionary EncodeDict { get; protected set; } + public IDictionary DecodeDict { get; protected set; } + public string Errors { get; set; } + public IDictionary ByteEncoder { get; set; } + public IDictionary ByteDecoder { get; set; } + public IDictionary, int> BpeRanks { get; set; } + public IDictionary Cache { get; set; } + public Regex Pattern { get; set; } + public Encoder(Dictionary encoder, IEnumerable> bpeMerges, string errors = "replace") + { + EncodeDict = encoder; + DecodeDict = encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key); + Errors = errors; + ByteEncoder = BytesToUnicode(); + ByteDecoder = ByteEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key); + BpeRanks = bpeMerges.Select(BpeConverter).ToDictionary(kvp => kvp.Key, kvp => kvp.Value); + Cache = new Dictionary(); + Pattern = new Regex(@"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"); + } + + private IDictionary BytesToUnicode() + { + const int start1 = '!'; + const int end1 = '~'; + const int start2 = '¡'; + const int end2 = '¬'; + const int start3 = '®'; + const int end3 = 'ÿ'; + + var dict = new Dictionary(); + for (var i = start1; i <= end1; i++) + { + dict[i] = Convert.ToChar(i); + } + for (var i = start2; i <= end2; i++) + { + dict[i] = Convert.ToChar(i); + } + for (var i = start3; i <= end3; i++) + { + dict[i] = Convert.ToChar(i); + } + + var n = 0; + for (var b = 0; b <= 0xFF; b++) + { + if (dict.ContainsKey(b)) continue; + dict[b] = Convert.ToChar(0x100 + n++); + } + + return dict; + } + + private static KeyValuePair, int> BpeConverter(Tuple tuple, int idx) => + new KeyValuePair, int>(tuple, idx); + + public static Encoder GetEncoder(string modelName, string modelsDir) + { + var encoder = + JsonConvert.DeserializeObject>( + File.ReadAllText(Path.Combine(modelsDir, modelName, "encoder.json"))); + + var bpeMerges = new List>(File + .ReadAllLines(Path.Combine(modelsDir, modelName, "vocab.bpe")).Skip(1).Select(SplitWhitespace) + .Where(NotNull)); + + return new Encoder(encoder, bpeMerges); + } + + private static Tuple SplitWhitespace(string input) + { + var output = input.Split(' '); + if (output.Length != 2) return null; + return new Tuple(output[0], output[1]); + } + + private static bool NotNull(Tuple input) => input != null; + } +} diff --git a/ports/win32-csharp/src/Form1.Designer.cs b/ports/win32-csharp/src/Form1.Designer.cs new file mode 100644 index 00000000..2b47e99e --- /dev/null +++ b/ports/win32-csharp/src/Form1.Designer.cs @@ -0,0 +1,40 @@ +namespace AIDungeon.net +{ + partial class Form1 + { + /// + /// Required designer variable. + /// + private System.ComponentModel.IContainer components = null; + + /// + /// Clean up any resources being used. + /// + /// true if managed resources should be disposed; otherwise, false. + protected override void Dispose(bool disposing) + { + if (disposing && (components != null)) + { + components.Dispose(); + } + base.Dispose(disposing); + } + + #region Windows Form Designer generated code + + /// + /// Required method for Designer support - do not modify + /// the contents of this method with the code editor. + /// + private void InitializeComponent() + { + this.components = new System.ComponentModel.Container(); + this.AutoScaleMode = System.Windows.Forms.AutoScaleMode.Font; + this.ClientSize = new System.Drawing.Size(800, 450); + this.Text = "Form1"; + } + + #endregion + } +} + diff --git a/ports/win32-csharp/src/Form1.cs b/ports/win32-csharp/src/Form1.cs new file mode 100644 index 00000000..e14b7dfc --- /dev/null +++ b/ports/win32-csharp/src/Form1.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Data; +using System.Drawing; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Windows.Forms; +using Tensorflow; +using Tensorflow.Contrib.Train; +using static Tensorflow.Binding; + +namespace AIDungeon.net +{ + public partial class Form1 : Form + { + public Form1() + { + InitializeComponent(); + + var config = new ConfigProto {GpuOptions = new GPUOptions {AllowGrowth = true}}; + var sess = tf.Session(config); + var context = tf.placeholder(tf.int32, new TensorShape(1)); + + var output = SampleSequence() + + var saver = tf.train.Saver(); + var ckpt = tf.train.latest_checkpoint(@"F:\src\AIDungeon\generator\gpt2\models\model_v5"); + saver.restore(sess, ckpt); + } + + public object SampleSequence(HParams hparams, int length, string start_token, int batch_size, Tensor context, + int temperature, int top_k, int top_p) + { + + } + } +} diff --git a/ports/win32-csharp/src/GPT2Generator.cs b/ports/win32-csharp/src/GPT2Generator.cs new file mode 100644 index 00000000..31fc44c2 --- /dev/null +++ b/ports/win32-csharp/src/GPT2Generator.cs @@ -0,0 +1,66 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Newtonsoft.Json; +using Tensorflow; +using static Tensorflow.Binding; + +namespace AIDungeon.net +{ + public class GPT2Generator + { + public static readonly Random rng = new Random(); + + public int GenerateNum { get; set; } + public decimal Temp { get; set; } + public decimal TopK { get; set; } + public decimal TopP { get;set; } + public bool Censor { get; set; } + public string ModelName { get; set; } + public string ModelDir { get; set; } + public string CheckpointPath { get; set; } + public int BatchSize { get; set; } + public int Samples { get; set; } + public Encoder Encoder { get; set; } + public Session Session { get; set; } + public Tensor Context { get; set; } + public object Output { get; set; } + + + public GPT2Generator(int generateNum, decimal temperature, decimal topK, decimal topP, bool censor) + { + GenerateNum = generateNum; + Temp = temperature; + TopK = topK; + TopP = topP; + Censor = censor; + + ModelName = "model_v5"; + ModelDir = "generator/gpt2/models"; + CheckpointPath = Path.Combine(ModelDir, ModelName); + + var modelsDir = Environment.ExpandEnvironmentVariables(ModelDir); + BatchSize = 1; + Samples = 1; + + Encoder = Encoder.GetEncoder(ModelName, modelsDir); + var hParams = + JsonConvert.DeserializeObject>( + File.ReadAllText(Path.Combine(modelsDir, ModelName, "hparams.json"))); + foreach (var kvp in GPT2Model.DefaultHParams()) + { + if (!hParams.ContainsKey(kvp.Key)) hParams[kvp.Key] = kvp.Value; + } + + var seed = rng.Next(100001); + + var config = new ConfigProto {GpuOptions = new GPUOptions {AllowGrowth = true}}; + Session = tf.Session(config); + Context = tf.placeholder(tf.int32, new TensorShape(BatchSize)); + Output = Sample.SampleSequence(); + } + } +} diff --git a/ports/win32-csharp/src/GPT2Model.cs b/ports/win32-csharp/src/GPT2Model.cs new file mode 100644 index 00000000..7d5861a1 --- /dev/null +++ b/ports/win32-csharp/src/GPT2Model.cs @@ -0,0 +1,58 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow; +using static Tensorflow.Binding; + +namespace AIDungeon.net +{ + public class GPT2Model + { + public static Dictionary DefaultHParams() + { + return new Dictionary + {{"n_vocab", 0}, {"n_ctx", 1024}, {"n_embd", 768}, {"n_head", 12}, {"n_layer", 12}}; + } + + public static object Model(IDictionary hParams, Tensor x, Tensor past, string scope, _ReuseMode autoReuse) + { + using (var varScope = tf.variable_scope(name: scope, reuse: autoReuse == _ReuseMode.AUTO_REUSE)) + { + var results = new Dictionary(); + (int batch, Tensor sequence) = ShapeList(x); + + var wpe = tf.get_variable("wpe", new TensorShape(hParams["n_ctx"], hParams["n_embd"]), + initializer: tf.random_normal_initializer(stddev: 0.01F)); + + var wte = tf.get_variable("wte", new TensorShape(hParams["n_vocab"], hParams["n_embd"]), + initializer: tf.random_normal_initializer(stddev: 0.02F)); + + var past_length = past == null ? new Tensor(0) : tf.shape(past)[past.shape.Length - 1]; + + var h = tf.gather(wte, x) + tf.gather(wpe, PositionsFor(x, past_length)); + } + } + + private static Tensor PositionsFor(Tensor tokens, Tensor pastLength) + { + var batch_size = tf.shape(tokens)[0]; + var nSteps = tf.shape(tokens)[1]; + return ExpandTile(pastLength + tf.range(nSteps), batch_size); + } + + private static Tensor ExpandTile(Tensor value, Tensor size) + { + var outValue = tf.convert_to_tensor(value, name: "value"); + var nDims = outValue.shape.Length; + var multiples = new Tensor() + return tf.tile(tf.expand_dims(value, axis: 0), size + new Tensor(1) * nDims); + } + + private static (int batch, Tensor sequence) ShapeList(Tensor x) + { + return (x.shape[0], tf.shape(x)[1]); + } + } +} diff --git a/ports/win32-csharp/src/POC.csproj b/ports/win32-csharp/src/POC.csproj new file mode 100644 index 00000000..be922ece --- /dev/null +++ b/ports/win32-csharp/src/POC.csproj @@ -0,0 +1,190 @@ + + + + + + + + Debug + AnyCPU + {F41FBFAD-36BE-4C89-A6E2-161391063FE1} + WinExe + AIDungeon.net + AIDungeon.net + v4.7.1 + 512 + true + true + + + + + AnyCPU + true + full + false + bin\Debug\ + DEBUG;TRACE + prompt + 4 + + + AnyCPU + pdbonly + true + bin\Release\ + TRACE + prompt + 4 + + + true + bin\x64\Debug\ + DEBUG;TRACE + full + x64 + prompt + MinimumRecommendedRules.ruleset + true + + + bin\x64\Release\ + TRACE + true + pdbonly + x64 + prompt + MinimumRecommendedRules.ruleset + true + + + + ..\packages\Google.Protobuf.3.11.1\lib\net45\Google.Protobuf.dll + + + ..\packages\Microsoft.ML.1.4.0\lib\netstandard2.0\Microsoft.ML.Core.dll + + + ..\packages\Microsoft.ML.CpuMath.1.4.0\lib\netstandard2.0\Microsoft.ML.CpuMath.dll + + + ..\packages\Microsoft.ML.1.4.0\lib\netstandard2.0\Microsoft.ML.Data.dll + + + ..\packages\Microsoft.ML.DataView.1.4.0\lib\netstandard2.0\Microsoft.ML.DataView.dll + + + ..\packages\Microsoft.ML.1.4.0\lib\netstandard2.0\Microsoft.ML.KMeansClustering.dll + + + ..\packages\Microsoft.ML.1.4.0\lib\netstandard2.0\Microsoft.ML.PCA.dll + + + ..\packages\Microsoft.ML.1.4.0\lib\netstandard2.0\Microsoft.ML.StandardTrainers.dll + + + ..\packages\Microsoft.ML.TensorFlow.1.4.0\lib\netstandard2.0\Microsoft.ML.TensorFlow.dll + + + ..\packages\Microsoft.ML.1.4.0\lib\netstandard2.0\Microsoft.ML.Transforms.dll + + + ..\packages\Newtonsoft.Json.12.0.3\lib\net45\Newtonsoft.Json.dll + + + ..\packages\NumSharp.0.20.4\lib\netstandard2.0\NumSharp.Core.dll + + + + ..\packages\System.Buffers.4.5.0\lib\netstandard2.0\System.Buffers.dll + + + ..\packages\System.CodeDom.4.7.0\lib\net461\System.CodeDom.dll + + + ..\packages\System.Collections.Immutable.1.7.0\lib\netstandard2.0\System.Collections.Immutable.dll + + + + ..\packages\System.IO.FileSystem.AccessControl.4.7.0\lib\net461\System.IO.FileSystem.AccessControl.dll + + + ..\packages\System.Memory.4.5.3\lib\netstandard2.0\System.Memory.dll + + + + ..\packages\System.Numerics.Vectors.4.5.0\lib\net46\System.Numerics.Vectors.dll + + + ..\packages\System.Runtime.CompilerServices.Unsafe.4.7.0\lib\netstandard2.0\System.Runtime.CompilerServices.Unsafe.dll + + + ..\packages\System.Security.AccessControl.4.7.0\lib\net461\System.Security.AccessControl.dll + + + ..\packages\System.Security.Principal.Windows.4.7.0\lib\net461\System.Security.Principal.Windows.dll + + + ..\packages\System.Threading.Tasks.Dataflow.4.11.0\lib\netstandard2.0\System.Threading.Tasks.Dataflow.dll + + + + + + + + + + + + ..\packages\TensorFlow.NET.0.13.0\lib\netstandard2.0\TensorFlow.NET.dll + + + + + + Form + + + Form1.cs + + + + + + + + ResXFileCodeGenerator + Resources.Designer.cs + Designer + + + True + Resources.resx + + + + SettingsSingleFileGenerator + Settings.Designer.cs + + + True + Settings.settings + True + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + + + + \ No newline at end of file diff --git a/ports/win32-csharp/src/Program.cs b/ports/win32-csharp/src/Program.cs new file mode 100644 index 00000000..bda8d8bf --- /dev/null +++ b/ports/win32-csharp/src/Program.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using System.Windows.Forms; + +namespace AIDungeon.net +{ + static class Program + { + /// + /// The main entry point for the application. + /// + [STAThread] + static void Main() + { + Application.EnableVisualStyles(); + Application.SetCompatibleTextRenderingDefault(false); + Application.Run(new Form1()); + } + } +} diff --git a/ports/win32-csharp/src/Properties/AssemblyInfo.cs b/ports/win32-csharp/src/Properties/AssemblyInfo.cs new file mode 100644 index 00000000..5e13bb70 --- /dev/null +++ b/ports/win32-csharp/src/Properties/AssemblyInfo.cs @@ -0,0 +1,36 @@ +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("AIDungeon.net")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("AIDungeon.net")] +[assembly: AssemblyCopyright("Copyright © 2019")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("f41fbfad-36be-4c89-a6e2-161391063fe1")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("1.0.*")] +[assembly: AssemblyVersion("1.0.0.0")] +[assembly: AssemblyFileVersion("1.0.0.0")] diff --git a/ports/win32-csharp/src/Properties/Resources.Designer.cs b/ports/win32-csharp/src/Properties/Resources.Designer.cs new file mode 100644 index 00000000..c999a2b2 --- /dev/null +++ b/ports/win32-csharp/src/Properties/Resources.Designer.cs @@ -0,0 +1,71 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace AIDungeon.net.Properties +{ + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class Resources + { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal Resources() + { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager + { + get + { + if ((resourceMan == null)) + { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("AIDungeon.net.Properties.Resources", typeof(Resources).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture + { + get + { + return resourceCulture; + } + set + { + resourceCulture = value; + } + } + } +} diff --git a/ports/win32-csharp/src/Properties/Resources.resx b/ports/win32-csharp/src/Properties/Resources.resx new file mode 100644 index 00000000..af7dbebb --- /dev/null +++ b/ports/win32-csharp/src/Properties/Resources.resx @@ -0,0 +1,117 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + \ No newline at end of file diff --git a/ports/win32-csharp/src/Properties/Settings.Designer.cs b/ports/win32-csharp/src/Properties/Settings.Designer.cs new file mode 100644 index 00000000..9fe1a622 --- /dev/null +++ b/ports/win32-csharp/src/Properties/Settings.Designer.cs @@ -0,0 +1,30 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace AIDungeon.net.Properties +{ + + + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.Editors.SettingsDesigner.SettingsSingleFileGenerator", "11.0.0.0")] + internal sealed partial class Settings : global::System.Configuration.ApplicationSettingsBase + { + + private static Settings defaultInstance = ((Settings)(global::System.Configuration.ApplicationSettingsBase.Synchronized(new Settings()))); + + public static Settings Default + { + get + { + return defaultInstance; + } + } + } +} diff --git a/ports/win32-csharp/src/Properties/Settings.settings b/ports/win32-csharp/src/Properties/Settings.settings new file mode 100644 index 00000000..39645652 --- /dev/null +++ b/ports/win32-csharp/src/Properties/Settings.settings @@ -0,0 +1,7 @@ + + + + + + + diff --git a/ports/win32-csharp/src/Sample.cs b/ports/win32-csharp/src/Sample.cs new file mode 100644 index 00000000..9ae442cd --- /dev/null +++ b/ports/win32-csharp/src/Sample.cs @@ -0,0 +1,47 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow; +using static Tensorflow.Binding; + +namespace AIDungeon.net +{ + public class Sample + { + + public static Tensor SampleSequence(IDictionary hParams, int length, Tensor startToken, + int batchSize, Tensor context, decimal temperature, int topK, int topP) + { + if (startToken == null) + { + if (context == null) throw new ArgumentNullException(nameof(context), $"Specify exactly one of {nameof(startToken)} and {nameof(context)}"); + } + else + { + if (context != null) throw new ArgumentException($"Specify exactly one of {nameof(startToken)} and {nameof(context)}", nameof(context)); + context = tf.fill(new Tensor(new[]{batchSize, 1}), startToken); + } + + using (var ns = tf.name_scope("sample_sequence")) + { + var initialState = Body((past: null, prev: context, output: context)); + var tokens = control_flow_ops.while_loop<(Tensor past, Tensor prev, Tensor output)>(Cond, Body,) + } + } + + private static (object logits, object presents) Step(IDictionary hParams, Tensor tokens, + Tensor past = null) + { + var lm_output = GPT2Model.Model(hParams, tokens, past, "model", _ReuseMode.AUTO_REUSE); + } + + private static Tensor Cond((Tensor past, Tensor prev, Tensor output) _) => new Tensor(true); + + private static (Tensor past, Tensor prev, Tensor output) Body((Tensor past, Tensor prev, Tensor output) state) + { + + } + } +} diff --git a/ports/win32-csharp/src/packages.config b/ports/win32-csharp/src/packages.config new file mode 100644 index 00000000..b3abf7b2 --- /dev/null +++ b/ports/win32-csharp/src/packages.config @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file