{ "cells": [ { "cell_type": "markdown", "id": "682a55f6-9ecd-413d-bb53-a56bee7a5650", "metadata": {}, "source": [ "To enable interactive plots, run `jupyter labextension install jupyter-matplotlib` followed by:" ] }, { "cell_type": "code", "execution_count": 1, "id": "bee15f5f-a563-4f82-922e-72b2fabaf16a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Building jupyterlab assets (production, minimized)\n" ] } ], "source": [ "!jupyter labextension install jupyter-matplotlib" ] }, { "cell_type": "code", "execution_count": 2, "id": "cb59ab52-9a6d-497c-891b-5432abc6a6a4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install --upgrade ipympl matplotlib -q\n", "%matplotlib ipympl" ] }, { "cell_type": "code", "execution_count": 3, "id": "616a153d-bcec-4a1e-8836-d9f46015c2d3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install hdbscan pymatgen colorcet -q" ] }, { "cell_type": "code", "execution_count": 4, "id": "0be7ce91-2cc8-4c8a-9029-a0149f72a6a3", "metadata": { "tags": [] }, "outputs": [], "source": [ "import hdbscan\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "from sklearn import manifold\n", "from ipywidgets import interact, Output\n", "from IPython.display import clear_output\n", "\n", "import sys\n", "sys.path.append('..')\n", "sys.path.append('../autoencoder')\n", "\n", "from src.band_plotters import*\n", "from src.TensorImageNoised import *\n", "from src.cluster_plotters import *\n", "sys.path.append('/notebooks/band-fingerprint/autoencoder/')\n", "sys.path.append('/notebooks/band-fingerprint/src')\n", "\n", "from model import *\n", "import model as resnet_autoencoder\n", "\n", "from ae_misc import *\n", "\n", "%matplotlib inline\n", "#%matplotlib auto" ] }, { "cell_type": "markdown", "id": "743c187b-af62-4fae-bb20-06f4c47d47fe", "metadata": { "tags": [] }, "source": [ "# Enter Full Fingerprint Name - Specify Configuration Below\n", "Note perplexity can be changed for second tsne run - see configuration. See `fingerprints` folder to view possible `FINGERPRINT_NAME`'s." 
]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b25a7b2d-fb88-4ec7-a5cf-e1086466ff0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "FINGERPRINT_NAME = \"224_2channel_resnet_L=98_perplexity_30_length_98\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac09a9d1-2d3d-4a51-b406-fa3efa829c31",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Functions to View Materials/Clusters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aefe936b-50ff-48e4-8521-9f030d898ceb",
   "metadata": {},
   "source": [
    "For any fingerprint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3a0547ef-6998-4b67-a041-dc34238361cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_material_any(index, show_inp=False):\n",
    "    \"\"\"Plot the fingerprint vector of one material.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    index : str\n",
    "        Row label of the material in the global ``df`` (e.g. 2dm-3).\n",
    "    show_inp : bool, default False\n",
    "        If True, also display the band-structure image stored under\n",
    "        ``DATA_DIRECTORY/images/grayscale_4ev_linewidth3``.\n",
    "\n",
    "    Reads the notebook globals ``df``, ``fingerprint_cols`` and ``fp_dict``.\n",
    "    \"\"\"\n",
    "    _, ax = plt.subplots()\n",
    "    ax.plot(np.linspace(0, fp_dict[\"length\"], fp_dict[\"length\"]), df.loc[index, fingerprint_cols])\n",
    "    ax.set_xlabel(index)\n",
    "    ax.set_title(\"fingerprint\")\n",
    "    plt.show()\n",
    "\n",
    "    if show_inp:\n",
    "        image = Image.open(DATA_DIRECTORY/f\"images/grayscale_4ev_linewidth3/{index}.png\")\n",
    "        # Image.show() launches an external viewer; display() renders inline in the notebook.\n",
    "        display(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cab76e1-0b05-417f-a8dc-c7dfa999b962",
   "metadata": {},
   "source": [
    "For an autoencoder-based fingerprint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b2c20362-53a6-4d14-a31b-100036a2887a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest activation found in the fingerprint array; dividing the encoder\n",
    "# output by it keeps the displayed values at or below 1.\n",
    "FINGERPRINT_MAX_ACTIVATION = 3.982462167739868\n",
    "\n",
    "\n",
    "def show_tensor_image(tensor, ax=None):\n",
    "    \"\"\"Draw a channels-first image tensor with ``imshow`` (colour limits 0..1).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    tensor : torch.Tensor\n",
    "        Image laid out as (channels, height, width); it is permuted to\n",
    "        (height, width, channels) for matplotlib.\n",
    "    ax : matplotlib.axes.Axes, optional\n",
    "        Axes to draw on. A new figure is created when omitted.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    matplotlib.axes.Axes\n",
    "        The axes the image was drawn on.\n",
    "    \"\"\"\n",
    "    if ax is None:\n",
    "        _, ax = plt.subplots()\n",
    "    ax.imshow(tensor.permute(1, 2, 0).detach().numpy(), vmax=1, vmin=0)\n",
    "    return ax\n",
    "\n",
    "\n",
    "def view_material_ae_2_channel(index, inp_recon_together=False):\n",
    "    \"\"\"Inspect one material for the 2-channel autoencoder fingerprint.\n",
    "\n",
    "    Plots the fingerprint and prints the material's ``member_strength``. If a\n",
    "    learner is loaded (``learn`` is not None) it then shows the 2-channel\n",
    "    encoder output as an RGB image (channel 0 -> red, channel 1 -> blue),\n",
    "    followed by the model input and its reconstruction. Without a learner\n",
    "    only the band-structure image is shown.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    index : str\n",
    "        Row label of the material in the global ``df``.\n",
    "    inp_recon_together : bool, default False\n",
    "        If True, stack the input and the reconstruction in one figure;\n",
    "        otherwise each gets its own figure.\n",
    "    \"\"\"\n",
    "    view_material_any(index)\n",
    "\n",
    "    print(df.loc[index][\"member_strength\"])\n",
    "\n",
    "    image_filename = DATA_DIRECTORY/f\"images/grayscale_4ev_linewidth3/{index}.png\"\n",
    "\n",
    "    if learn is None:\n",
    "        print(f\"Model not loaded, displaying bandstructure of {index}:\")\n",
    "        show_image(load_image(image_filename))\n",
    "        return\n",
    "\n",
    "    # Get input and prediction\n",
    "    _, out, _ = learn.predict(image_filename)\n",
    "\n",
    "    # Run the encoder on the same image to get the 2-channel fingerprint\n",
    "    dl = learn.dls.test_dl([image_filename])\n",
    "    with torch.no_grad():\n",
    "        inp = dl.one_batch()[0]\n",
    "        fingerprint_2_channel = learn.model.encoder(inp)[0]\n",
    "\n",
    "    fingerprint_2_channel = fingerprint_2_channel / FINGERPRINT_MAX_ACTIVATION\n",
    "    # RGB layout: channel 0 -> red, empty green, channel 1 -> blue.\n",
    "    rgb_image = torch.stack(\n",
    "        [fingerprint_2_channel[0], torch.zeros_like(fingerprint_2_channel[0]), fingerprint_2_channel[1]],\n",
    "        dim=0,\n",
    "    )\n",
    "    show_tensor_image(rgb_image)\n",
    "\n",
    "    if inp_recon_together:\n",
    "        # subplots(2, 1) returns an array of two Axes: one per image.\n",
    "        _, (ax_inp, ax_recon) = plt.subplots(2, 1)\n",
    "    else:\n",
    "        ax_inp = ax_recon = None\n",
    "\n",
    "    show_tensor_image(inp[0], ax_inp)\n",
    "    show_tensor_image(torch.sigmoid(out), ax_recon)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a8fe6c6f-cf5b-47df-aae6-6cf66f7e0ae3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_cluster(label):\n",
    "    \"\"\"Show the members of one cluster, then inspect each member.\n",
    "\n",
    "    Displays up to 200 member rows (``BORING_COLUMNS`` hidden) and calls\n",
    "    ``view_material_function`` on every member, producing one set of plots\n",
    "    per material.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    label\n",
    "        Cluster label to select, compared against ``df.labels``.\n",
    "    \"\"\"\n",
    "    members = df[df.labels == label]\n",
    "    display(members.drop(columns=BORING_COLUMNS, errors=\"ignore\").head(200))\n",
    "\n",
    "    for index in members.index:\n",
    "        view_material_function(index)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24fe7559-a55f-4f94-9027-852a27843075",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Fingerprint Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "671c4989-f5ea-4e3b-b441-30c87d20ffe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for each fingerprint, keyed by FINGERPRINT_NAME:\n",
    "#   hdbscan   - HDBSCAN clustering parameters\n",
    "#   tsne      - parameters for the second t-SNE run\n",
    "#   length    - fingerprint length (x-axis extent in view_material_any)\n",
    "#   model     - path passed to load_learner, or False if there is no model\n",
    "#   view_func - function used to inspect a single material\n",
    "FINGERPRINTS = {\n",
    "    \"224_2channel_resnet_L=98_perplexity_30_length_98\": {\n",
    "        \"hdbscan\": {\"min_cluster_size\": 5, \"min_samples\": 2},\n",
    "        \"tsne\": {\"perplexity\": 10, \"early_exaggeration\": 12.0, \"learning_rate\": \"auto\"},\n",
    "        \"length\": 98,\n",
    "        \"model\": \"../autoencoder/trained_models/resnet18_size224_lossbce_channels2.pkl\",\n",
    "        \"view_func\": view_material_ae_2_channel,\n",
    "    },\n",
    "\n",
    "    \"all_k_branches_histogram_-8_to_8_perplexity_30_length_60\": {\n",
    "        \"hdbscan\": {\"min_cluster_size\": 4, \"min_samples\": 3},\n",
    "        \"tsne\": {\"perplexity\": 30, \"early_exaggeration\": 12.0, \"learning_rate\": \"auto\"},\n",
    "        \"length\": 60,\n",
    "        \"model\": False,\n",
    "        \"view_func\": lambda mat_id: view_material_any(mat_id, show_inp=True),\n",
    "    },\n",
    "}\n",
    "\n",
    "fp_dict = FINGERPRINTS[FINGERPRINT_NAME]\n",
    "hdbscan_dict = fp_dict[\"hdbscan\"]\n",
    "tsne_dict = fp_dict[\"tsne\"]\n",
    "length = fp_dict[\"length\"]\n",
    "view_material_function = fp_dict[\"view_func\"]\n",
    "\n",
    "\n",
    "FLAT_ONLY = True\n",
    "# Columns hidden by view_cluster when listing cluster members.\n",
    "BORING_COLUMNS = [\"flat_segments\", \"flatness_score\", \"binary_flatness\", \"horz_flat_seg\", \"exfoliation_eg\", \"A\", \"B\", \"C\", \"D\", \"E\", \"F\"]\n",
    "INPUT_NAME = FINGERPRINT_NAME + \".csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3d32fd3c-7b55-41e6-a37c-e37dc8f29912",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# `learn` is always defined after this cell (None when there is no usable\n",
    "# model), so the view functions can test it instead of hitting a NameError.\n",
    "learn = None\n",
    "if fp_dict[\"model\"]:\n",
    "    try:\n",
    "        learn = load_learner(fp_dict[\"model\"])\n",
    "    except Exception as err:  # best effort: view functions fall back to raw images\n",
    "        print(f\"Model not found: {fp_dict['model']} ({type(err).__name__}: {err})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81594502-7893-4d49-8df3-d88932fd59c9",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "55431b8d-b4ed-4ac3-b6ca-cbd762fa8d74",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [ "
| \n", " | formula | \n", "gen_formula | \n", "space_group | \n", "segments | \n", "flat_segments | \n", "flatness_score | \n", "discovery | \n", "binary_flatness | \n", "horz_flat_seg | \n", "exfoliation_eg | \n", "... | \n", "90 | \n", "91 | \n", "92 | \n", "93 | \n", "94 | \n", "95 | \n", "96 | \n", "97 | \n", "fx | \n", "fy | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| ID | \n", "\n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " |
| 2dm-3 | \n", "TlS | \n", "AB | \n", "2 | \n", "4 | \n", "4 | \n", "0.84646 | \n", "bottom-up | \n", "1 | \n", "3 | \n", "0.095794 | \n", "... | \n", "2.325100 | \n", "1.733577 | \n", "1.816098 | \n", "1.953408 | \n", "1.904952 | \n", "1.718350 | \n", "1.920829 | \n", "1.940830 | \n", "-17.031164 | \n", "-22.583645 | \n", "
| 2dm-21 | \n", "TaI3 | \n", "AB3 | \n", "162 | \n", "3 | \n", "3 | \n", "0.88201 | \n", "bottom-up | \n", "1 | \n", "1 | \n", "0.097255 | \n", "... | \n", "2.226489 | \n", "2.289640 | \n", "2.589102 | \n", "2.477097 | \n", "2.594307 | \n", "2.217409 | \n", "2.636327 | \n", "2.281451 | \n", "-85.320190 | \n", "19.231680 | \n", "
| 2dm-22 | \n", "Li2O | \n", "AB2 | \n", "164 | \n", "3 | \n", "3 | \n", "0.96678 | \n", "bottom-up | \n", "1 | \n", "4 | \n", "0.037593 | \n", "... | \n", "2.502424 | \n", "1.772349 | \n", "1.775590 | \n", "1.941162 | \n", "1.745729 | \n", "1.716548 | \n", "1.714002 | \n", "1.817789 | \n", "74.049220 | \n", "-91.258260 | \n", "
| 2dm-25 | \n", "VBr4 | \n", "AB4 | \n", "123 | \n", "3 | \n", "3 | \n", "0.97834 | \n", "bottom-up | \n", "1 | \n", "2 | \n", "0.140290 | \n", "... | \n", "2.372130 | \n", "2.224227 | \n", "2.399812 | \n", "2.197275 | \n", "2.769758 | \n", "2.584419 | \n", "2.809708 | \n", "2.114330 | \n", "36.688946 | \n", "55.726463 | \n", "
| 2dm-29 | \n", "SBr | \n", "AB | \n", "2 | \n", "4 | \n", "4 | \n", "0.82037 | \n", "bottom-up | \n", "1 | \n", "3 | \n", "0.067035 | \n", "... | \n", "2.852715 | \n", "2.168420 | \n", "2.470804 | \n", "2.357963 | \n", "2.278249 | \n", "2.280239 | \n", "2.296243 | \n", "2.235125 | \n", "11.995536 | \n", "-106.832634 | \n", "
5 rows × 126 columns
\n", "