Skip to content

Commit

Permalink
Fix invalid module name with dash
Browse files Browse the repository at this point in the history
  • Loading branch information
sanders41 committed Dec 19, 2023
1 parent fe133be commit 1c4206c
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 46 deletions.
3 changes: 2 additions & 1 deletion src/file_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ pub fn save_file_with_content(file_path: &PathBuf, file_content: &str) -> Result
}

pub fn save_empty_src_file(project_info: &ProjectInfo, file_name: &str) -> Result<()> {
let module = project_info.source_dir.replace('-', "_");
let file_path = project_info
.base_dir()
.join(format!("{}/{}", &project_info.source_dir, file_name));
.join(format!("{}/{}", &module, file_name));
File::create(file_path)?;

Ok(())
Expand Down
25 changes: 14 additions & 11 deletions src/project_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ use crate::python_files::generate_python_files;
use crate::rust_files::{save_cargo_toml_file, save_lib_file};

fn create_directories(project_info: &ProjectInfo) -> Result<()> {
let module = project_info.source_dir.replace('-', "_");
let base = project_info.base_dir();
let src = base.join(&project_info.source_dir);
let src = base.join(&module);
create_dir_all(src)?;

let github_dir = base.join(".github/workflows");
Expand Down Expand Up @@ -330,6 +331,7 @@ fn build_latest_dev_dependencies(
}

fn create_pyproject_toml(project_info: &ProjectInfo) -> String {
let module = project_info.source_dir.replace('-', "_");
let pyupgrade_version = &project_info.min_python_version.replace(['.', '^'], "");
let license_text = license_str(&project_info.license);
let mut pyproject = match &project_info.project_manager {
Expand All @@ -347,7 +349,7 @@ license = "{{ license }}"
readme = "README.md"
[tool.maturin]
module-name = "{{ source_dir }}._{{ source_dir }}"
module-name = "{{ module }}._{{ module }}"
binding = "pyo3"
features = ["pyo3/extension-module"]
Expand Down Expand Up @@ -392,14 +394,14 @@ requires-python = ">={{ min_python_version }}"
dynamic = ["version", "readme"]
[tool.setuptools.dynamic]
version = {attr = "{{ source_dir }}.__version__"}
version = {attr = "{{ module }}.__version__"}
readme = {file = ["README.md"]}
[tool.setuptools.packages.find]
include = ["{{ source_dir }}*"]
include = ["{{ module }}*"]
[tool.setuptools.package-data]
{{ source_dir }} = ["py.typed"]
{{ module }} = ["py.typed"]
"#
.to_string(),
Expand All @@ -416,7 +418,7 @@ disallow_untyped_defs = false
[tool.pytest.ini_options]
minversion = "6.0"
addopts = "--cov={{ source_dir }} --cov-report term-missing --no-cov-on-fail"
addopts = "--cov={{ module }} --cov-report term-missing --no-cov-on-fail"
[tool.coverage.report]
exclude_lines = ["if __name__ == .__main__.:", "pragma: no cover"]
Expand Down Expand Up @@ -450,7 +452,7 @@ fix = true

render!(
&pyproject,
project_name => project_info.source_dir.replace('_', "-"),
project_name => module.replace('_', "-"),
version => project_info.version,
project_description => project_info.project_description,
creator => project_info.creator,
Expand All @@ -459,7 +461,7 @@ fix = true
min_python_version => project_info.min_python_version,
dev_dependencies => build_latest_dev_dependencies(project_info.is_application, project_info.download_latest_packages, &project_info.project_manager, &project_info.min_python_version),
max_line_length => project_info.max_line_length,
source_dir => project_info.source_dir,
module => module,
is_application => project_info.is_application,
pyupgrade_version => pyupgrade_version,
)
Expand Down Expand Up @@ -488,7 +490,7 @@ fn save_dev_requirements(project_info: &ProjectInfo) -> Result<()> {
Ok(())
}

fn create_pyo3_justfile(source_dir: &str) -> String {
fn create_pyo3_justfile(module: &str) -> String {
format!(
r#"@develop:
maturin develop
Expand Down Expand Up @@ -531,13 +533,14 @@ fn create_pyo3_justfile(source_dir: &str) -> String {
@test:
pytest
"#,
source_dir
module
)
}

fn save_pyo3_justfile(project_info: &ProjectInfo) -> Result<()> {
let module = project_info.source_dir.replace('-', "_");
let file_path = project_info.base_dir().join("justfile");
let content = create_pyo3_justfile(&project_info.source_dir);
let content = create_pyo3_justfile(&module);

save_file_with_content(&file_path, &content)?;

Expand Down
5 changes: 4 additions & 1 deletion src/project_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ pub fn get_project_info(use_defaults: bool) -> Result<ProjectInfo> {
bail!(format!("The {project_slug} directory already exists"));
}

let source_dir_default = project_name.replace(' ', "_").to_lowercase();
let source_dir_default = project_name
.replace(' ', "_")
.replace('-', "_")
.to_lowercase();
let source_dir = if use_defaults {
source_dir_default
} else {
Expand Down
68 changes: 37 additions & 31 deletions src/python_files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use anyhow::{bail, Result};
use crate::file_manager::save_file_with_content;
use crate::project_info::{ProjectInfo, ProjectManager};

fn create_dunder_main_file(source_dir: &str) -> String {
fn create_dunder_main_file(module: &str) -> String {
format!(
r#"from {source_dir}.main import main # pragma: no cover
r#"from {module}.main import main # pragma: no cover
if __name__ == "__main__":
raise SystemExit(main())
Expand All @@ -29,23 +29,24 @@ if __name__ == "__main__":
}

fn save_main_files(project_info: &ProjectInfo) -> Result<()> {
let src = project_info.base_dir().join(&project_info.source_dir);
let module = project_info.source_dir.replace('-', "_");
let src = project_info.base_dir().join(&module);
let main = src.join("main.py");
let main_content = create_main_file();

save_file_with_content(&main, &main_content)?;

let main_dunder = src.join("__main__.py");
let main_dunder_content = create_dunder_main_file(&project_info.source_dir);
let main_dunder_content = create_dunder_main_file(&module);

save_file_with_content(&main_dunder, &main_dunder_content)?;

Ok(())
}

fn create_main_test_file(source_dir: &str) -> String {
fn create_main_test_file(module: &str) -> String {
format!(
r#"from {source_dir}.main import main
r#"from {module}.main import main
def test_main():
Expand All @@ -55,17 +56,18 @@ def test_main():
}

fn save_main_test_file(project_info: &ProjectInfo) -> Result<()> {
let module = project_info.source_dir.replace('-', "_");
let file_path = project_info.base_dir().join("tests/test_main.py");
let content = create_main_test_file(&project_info.source_dir);
let content = create_main_test_file(&module);

save_file_with_content(&file_path, &content)?;

Ok(())
}

fn create_pyo3_test_file(source_dir: &str) -> String {
fn create_pyo3_test_file(module: &str) -> String {
format!(
r#"from {source_dir} import sum_as_string
r#"from {module} import sum_as_string
def test_sum_as_string():
Expand All @@ -75,25 +77,26 @@ def test_sum_as_string():
}

fn save_pyo3_test_file(project_info: &ProjectInfo) -> Result<()> {
let module = project_info.source_dir.replace('-', "_");
let file_path = project_info
.base_dir()
.join(format!("tests/test_{}.py", &project_info.source_dir));
let content = create_pyo3_test_file(&project_info.source_dir);
.join(format!("tests/test_{}.py", &module));
let content = create_pyo3_test_file(&module);

save_file_with_content(&file_path, &content)?;

Ok(())
}

fn create_project_init_file(source_dir: &str, project_manager: &ProjectManager) -> String {
fn create_project_init_file(module: &str, project_manager: &ProjectManager) -> String {
match project_manager {
ProjectManager::Maturin => {
let v_ascii: u8 = 118;
if let Some(first_char) = source_dir.chars().next() {
if let Some(first_char) = module.chars().next() {
if (first_char as u8) < v_ascii {
format!(
r#"from {source_dir}._{source_dir} import sum_as_string
from {source_dir}._version import VERSION
r#"from {module}._{module} import sum_as_string
from {module}._version import VERSION
__version__ = VERSION
Expand All @@ -103,8 +106,8 @@ __all__ = ["sum_as_string"]
)
} else {
format!(
r#"from {source_dir}._version import VERSION
from {source_dir}._{source_dir} import sum_as_string
r#"from {module}._version import VERSION
from {module}._{module} import sum_as_string
__version__ = VERSION
Expand All @@ -115,8 +118,8 @@ __all__ = ["sum_as_string"]
}
} else {
format!(
r#"from {source_dir}._{source_dir} import sum_as_string
r#"from {source_dir}._version import VERSION
r#"from {module}._{module} import sum_as_string
r#"from {module}._version import VERSION
__version__ = VERSION
Expand All @@ -128,7 +131,7 @@ __all__ = ["sum_as_string"]
}
_ => {
format!(
r#"from {source_dir}._version import VERSION
r#"from {module}._version import VERSION
__version__ = VERSION
"#
Expand All @@ -145,10 +148,11 @@ fn save_test_init_file(project_info: &ProjectInfo) -> Result<()> {
}

fn save_project_init_file(project_info: &ProjectInfo) -> Result<()> {
let module = project_info.source_dir.replace('-', "_");
let file_path = project_info
.base_dir()
.join(format!("{}/__init__.py", &project_info.source_dir));
let content = create_project_init_file(&project_info.source_dir, &project_info.project_manager);
.join(format!("{}/__init__.py", &module));
let content = create_project_init_file(&module, &project_info.project_manager);

save_file_with_content(&file_path, &content)?;

Expand All @@ -164,10 +168,10 @@ def sum_as_string(a: int, b: int) -> str: ...
}

pub fn save_pyi_file(project_info: &ProjectInfo) -> Result<()> {
let file_path = project_info.base_dir().join(format!(
"{}/_{}.pyi",
&project_info.source_dir, &project_info.source_dir
));
let module = project_info.source_dir.replace('-', "_");
let file_path = project_info
.base_dir()
.join(format!("{}/_{}.pyi", &module, &module));
let content = create_pyi_file();

save_file_with_content(&file_path, &content)?;
Expand All @@ -180,17 +184,18 @@ fn create_version_file(version: &str) -> String {
}

fn save_version_file(project_info: &ProjectInfo) -> Result<()> {
let module = project_info.source_dir.replace('-', "_");
let file_path = project_info
.base_dir()
.join(format!("{}/_version.py", &project_info.source_dir));
.join(format!("{}/_version.py", &module));
let content = create_version_file(&project_info.version);

save_file_with_content(&file_path, &content)?;

Ok(())
}

fn create_version_test_file(source_dir: &str, project_manager: &ProjectManager) -> String {
fn create_version_test_file(module: &str, project_manager: &ProjectManager) -> String {
let version_test: &str = match project_manager {
ProjectManager::Maturin => {
r#"def test_versions_match():
Expand All @@ -212,7 +217,7 @@ fn create_version_test_file(source_dir: &str, project_manager: &ProjectManager)
}
ProjectManager::Setuptools => {
return format!(
r#"from {source_dir}._version import VERSION
r#"from {module}._version import VERSION
def test_versions_match():
assert VERSION == "0.1.0"
Expand All @@ -225,7 +230,7 @@ def test_versions_match():
r#"import sys
from pathlib import Path
from {source_dir}._version import VERSION
from {module}._version import VERSION
if sys.version_info < (3, 11):
import tomli as tomllib
Expand All @@ -239,8 +244,9 @@ else:
}

fn save_version_test_file(project_info: &ProjectInfo) -> Result<()> {
let module = project_info.source_dir.replace('-', "_");
let file_path = project_info.base_dir().join("tests/test_version.py");
let content = create_version_test_file(&project_info.source_dir, &project_info.project_manager);
let content = create_version_test_file(&module, &project_info.project_manager);
save_file_with_content(&file_path, &content)?;

Ok(())
Expand Down
6 changes: 4 additions & 2 deletions src/rust_files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ fn create_cargo_toml_file(
) -> String {
let versions = build_latest_dependencies(min_python_version, download_latest_packages);
let license = license_str(license_type);
let name = source_dir.replace('-', "_");

format!(
r#"[package]
Expand All @@ -69,7 +70,7 @@ license = "{license}"
readme = "README.md"
[lib]
name = "_{source_dir}"
name = "_{name}"
crate-type = ["cdylib"]
[dependencies]
Expand All @@ -95,6 +96,7 @@ pub fn save_cargo_toml_file(project_info: &ProjectInfo) -> Result<()> {
}

fn create_lib_file(source_dir: &str) -> String {
let module = source_dir.replace('-', "_");
format!(
r#"use pyo3::prelude::*;
Expand All @@ -104,7 +106,7 @@ fn sum_as_string(a: usize, b: usize) -> PyResult<String> {{
}}
#[pymodule]
fn _{source_dir}(_py: Python, m: &PyModule) -> PyResult<()> {{
fn _{module}(_py: Python, m: &PyModule) -> PyResult<()> {{
m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
Ok(())
}}
Expand Down

0 comments on commit 1c4206c

Please sign in to comment.